From 0e6468c72494fa86d261f6bc3c89441222ba981b Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Thu, 25 Mar 2021 22:06:35 +0800 Subject: [PATCH] refine trainer (#2412) * refine trainer --- ppdet/data/reader.py | 1 + ppdet/data/source/__init__.py | 2 ++ ppdet/{metrics => data/source}/category.py | 0 ppdet/data/source/coco.py | 10 +++++----- ppdet/data/source/dataset.py | 4 ++++ ppdet/engine/export_utils.py | 2 +- ppdet/engine/trainer.py | 15 ++++++++++++--- ppdet/metrics/__init__.py | 6 +----- ppdet/metrics/metrics.py | 6 ++++-- 9 files changed, 30 insertions(+), 16 deletions(-) rename ppdet/{metrics => data/source}/category.py (100%) diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index 126051734..294ae95c6 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -184,6 +184,7 @@ class BaseDataLoader(object): batch_sampler=None, return_list=False): self.dataset = dataset + self.dataset.check_or_download_dataset() self.dataset.parse_dataset() # get data self.dataset.set_transform(self._sample_transforms) diff --git a/ppdet/data/source/__init__.py b/ppdet/data/source/__init__.py index 60c205d14..d1d2561d5 100644 --- a/ppdet/data/source/__init__.py +++ b/ppdet/data/source/__init__.py @@ -16,7 +16,9 @@ from . import coco # TODO add voc and widerface dataset from . import voc #from . import widerface +from . import category from .coco import * from .voc import * #from .widerface import * +from .category import * diff --git a/ppdet/metrics/category.py b/ppdet/data/source/category.py similarity index 100% rename from ppdet/metrics/category.py rename to ppdet/data/source/category.py diff --git a/ppdet/data/source/coco.py b/ppdet/data/source/coco.py index 387229136..18625cfb5 100644 --- a/ppdet/data/source/coco.py +++ b/ppdet/data/source/coco.py @@ -49,10 +49,10 @@ class COCODataSet(DetDataset): records = [] ct = 0 - catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)}) - cname2cid = dict({ + self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)}) + self.cname2cid = dict({ coco.loadCats(catid)[0]['name']: clsid - for catid, clsid in catid2clsid.items() + for catid, clsid in self.catid2clsid.items() }) if 'annotations' not in coco.dataset: @@ -119,7 +119,7 @@ class COCODataSet(DetDataset): has_segmentation = False for i, box in enumerate(bboxes): catid = box['category_id'] - gt_class[i][0] = catid2clsid[catid] + gt_class[i][0] = self.catid2clsid[catid] gt_bbox[i, :] = box['clean_bbox'] is_crowd[i][0] = box['iscrowd'] # check RLE format @@ -163,4 +163,4 @@ class COCODataSet(DetDataset): break assert len(records) > 0, 'not found any coco record in %s' % (anno_path) logger.debug('{} samples in file {}'.format(ct, anno_path)) - self.roidbs, self.cname2cid = records, cname2cid + self.roidbs = records diff --git a/ppdet/data/source/dataset.py b/ppdet/data/source/dataset.py index 429cdc7a5..76d41f05d 100644 --- a/ppdet/data/source/dataset.py +++ b/ppdet/data/source/dataset.py @@ -67,6 +67,10 @@ class DetDataset(Dataset): return self.transform(roidb) + def check_or_download_dataset(self): + self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path, + self.image_dir) + def set_kwargs(self, **kwargs): self.mixup_epoch = kwargs.get('mixup_epoch', -1) self.cutmix_epoch = kwargs.get('cutmix_epoch', -1) diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index c1b140615..bc3092a5a 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -20,7 +20,7 @@ import os import yaml from collections import OrderedDict -from ppdet.metrics import get_categories +from ppdet.data.source.category import get_categories from ppdet.utils.logger import setup_logger logger = setup_logger('ppdet.engine') diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 57c16e4b8..37a244d61 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -31,7 +31,8 @@ from paddle.static import InputSpec from ppdet.core.workspace import create from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.visualizer import visualize_results -from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_categories, get_infer_results +from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results +from ppdet.data.source.category import get_categories import ppdet.utils.stats as stats from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter @@ -116,8 +117,8 @@ class Trainer(object): self._callbacks = [] self._compose_callback = None - def _init_metrics(self): - if self.mode == 'test': + def _init_metrics(self, validate=False): + if self.mode == 'test' or (self.mode == 'train' and not validate): self._metrics = [] return classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False @@ -126,9 +127,12 @@ class Trainer(object): bias = self.cfg['bias'] if 'bias' in self.cfg else 0 output_eval = self.cfg['output_eval'] \ if 'output_eval' in self.cfg else None + clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \ + if self.mode == 'eval' else None self._metrics = [ COCOMetric( anno_file=self.dataset.get_anno(), + clsid2catid=clsid2catid, classwise=classwise, output_eval=output_eval, bias=bias) @@ -186,6 +190,11 @@ class Trainer(object): def train(self, validate=False): assert self.mode == 'train', "Model not in 'train' mode" + # if validation in training is enabled, metrics should be re-init + if validate: + self._init_metrics(validate=validate) + self._reset_metrics() + model = self.model if self.cfg.fleet: model = fleet.distributed_model(model) diff --git a/ppdet/metrics/__init__.py b/ppdet/metrics/__init__.py index fb7add57c..460b12dea 100644 --- a/ppdet/metrics/__init__.py +++ b/ppdet/metrics/__init__.py @@ -15,8 +15,4 @@ from . import metrics from .metrics import * -from . import category -from .category import * - -__all__ = metrics.__all__ \ - + category.__all__ +__all__ = metrics.__all__ diff --git a/ppdet/metrics/metrics.py b/ppdet/metrics/metrics.py index efd34c21b..cae3a2e09 100644 --- a/ppdet/metrics/metrics.py +++ b/ppdet/metrics/metrics.py @@ -22,10 +22,10 @@ import json import paddle import numpy as np -from .category import get_categories from .map_utils import prune_zero_padding, DetectionMAP from .coco_utils import get_infer_results, cocoapi_eval from .widerface_utils import face_eval_run +from ppdet.data.source.category import get_categories from ppdet.utils.logger import setup_logger logger = setup_logger(__name__) @@ -62,7 +62,9 @@ class COCOMetric(Metric): assert os.path.isfile(anno_file), \ "anno_file {} not a file".format(anno_file) self.anno_file = anno_file - self.clsid2catid, self.catid2name = get_categories('COCO', anno_file) + self.clsid2catid = kwargs.get('clsid2catid', None) + if self.clsid2catid is None: + self.clsid2catid, _ = get_categories('COCO', anno_file) self.classwise = kwargs.get('classwise', False) self.output_eval = kwargs.get('output_eval', None) # TODO: bias should be unified -- GitLab