未验证 提交 ea8e8ebd 编写于 作者: W wangxinxin08 提交者: GitHub

refine code to avoid some problem (#1772)

上级 aa16d88a
...@@ -6,7 +6,6 @@ TrainDataset: ...@@ -6,7 +6,6 @@ TrainDataset:
image_dir: train2017 image_dir: train2017
anno_path: annotations/instances_train2017.json anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco dataset_dir: dataset/coco
mixup_epoch: 250
EvalDataset: EvalDataset:
!COCODataSet !COCODataSet
......
...@@ -21,6 +21,7 @@ TrainReader: ...@@ -21,6 +21,7 @@ TrainReader:
batch_size: 8 batch_size: 8
shuffle: true shuffle: true
drop_last: true drop_last: true
mixup_epoch: 250
EvalReader: EvalReader:
......
...@@ -30,6 +30,7 @@ class Compose(object): ...@@ -30,6 +30,7 @@ class Compose(object):
if hasattr(op_cls, 'num_classes'): if hasattr(op_cls, 'num_classes'):
op_cls.num_classes = num_classes op_cls.num_classes = num_classes
# TODO: should be refined in the future
if op_cls in [ if op_cls in [
transform.Gt2YoloTargetOp, transform.Gt2YoloTarget transform.Gt2YoloTargetOp, transform.Gt2YoloTarget
]: ]:
...@@ -89,7 +90,8 @@ class BaseDataLoader(object): ...@@ -89,7 +90,8 @@ class BaseDataLoader(object):
drop_last=False, drop_last=False,
drop_empty=True, drop_empty=True,
num_classes=81, num_classes=81,
with_background=True): with_background=True,
**kwargs):
# out fields # out fields
self._fields = inputs_def['fields'] if inputs_def else None self._fields = inputs_def['fields'] if inputs_def else None
# sample transform # sample transform
...@@ -107,6 +109,7 @@ class BaseDataLoader(object): ...@@ -107,6 +109,7 @@ class BaseDataLoader(object):
self.shuffle = shuffle self.shuffle = shuffle
self.drop_last = drop_last self.drop_last = drop_last
self.with_background = with_background self.with_background = with_background
self.kwargs = kwargs
def __call__(self, def __call__(self,
dataset, dataset,
...@@ -120,6 +123,8 @@ class BaseDataLoader(object): ...@@ -120,6 +123,8 @@ class BaseDataLoader(object):
# get data # get data
self._dataset.set_out(self._sample_transforms, self._dataset.set_out(self._sample_transforms,
copy.deepcopy(self._fields)) copy.deepcopy(self._fields))
# set kwargs
self._dataset.set_kwargs(**self.kwargs)
# batch sampler # batch sampler
if batch_sampler is None: if batch_sampler is None:
self._batch_sampler = DistributedBatchSampler( self._batch_sampler = DistributedBatchSampler(
...@@ -154,10 +159,12 @@ class TrainReader(BaseDataLoader): ...@@ -154,10 +159,12 @@ class TrainReader(BaseDataLoader):
drop_last=True, drop_last=True,
drop_empty=True, drop_empty=True,
num_classes=81, num_classes=81,
with_background=True): with_background=True,
super(TrainReader, self).__init__( **kwargs):
inputs_def, sample_transforms, batch_transforms, batch_size, super(TrainReader, self).__init__(inputs_def, sample_transforms,
shuffle, drop_last, drop_empty, num_classes, with_background) batch_transforms, batch_size, shuffle,
drop_last, drop_empty, num_classes,
with_background, **kwargs)
@register @register
...@@ -171,10 +178,12 @@ class EvalReader(BaseDataLoader): ...@@ -171,10 +178,12 @@ class EvalReader(BaseDataLoader):
drop_last=True, drop_last=True,
drop_empty=True, drop_empty=True,
num_classes=81, num_classes=81,
with_background=True): with_background=True,
super(EvalReader, self).__init__( **kwargs):
inputs_def, sample_transforms, batch_transforms, batch_size, super(EvalReader, self).__init__(inputs_def, sample_transforms,
shuffle, drop_last, drop_empty, num_classes, with_background) batch_transforms, batch_size, shuffle,
drop_last, drop_empty, num_classes,
with_background, **kwargs)
@register @register
...@@ -188,7 +197,9 @@ class TestReader(BaseDataLoader): ...@@ -188,7 +197,9 @@ class TestReader(BaseDataLoader):
drop_last=False, drop_last=False,
drop_empty=True, drop_empty=True,
num_classes=81, num_classes=81,
with_background=True): with_background=True,
super(TestReader, self).__init__( **kwargs):
inputs_def, sample_transforms, batch_transforms, batch_size, super(TestReader, self).__init__(inputs_def, sample_transforms,
shuffle, drop_last, drop_empty, num_classes, with_background) batch_transforms, batch_size, shuffle,
drop_last, drop_empty, num_classes,
with_background, **kwargs)
...@@ -28,18 +28,9 @@ class COCODataSet(DetDataset): ...@@ -28,18 +28,9 @@ class COCODataSet(DetDataset):
dataset_dir=None, dataset_dir=None,
image_dir=None, image_dir=None,
anno_path=None, anno_path=None,
mixup_epoch=-1,
cutmix_epoch=-1,
mosaic_epoch=-1,
sample_num=-1): sample_num=-1):
super(COCODataSet, self).__init__( super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
dataset_dir, sample_num)
image_dir,
anno_path,
sample_num,
mixup_epoch=mixup_epoch,
cutmix_epoch=cutmix_epoch,
mosaic_epoch=mosaic_epoch)
self.load_image_only = False self.load_image_only = False
self.load_semantic = False self.load_semantic = False
......
...@@ -33,9 +33,6 @@ class DetDataset(Dataset): ...@@ -33,9 +33,6 @@ class DetDataset(Dataset):
anno_path=None, anno_path=None,
sample_num=-1, sample_num=-1,
use_default_label=None, use_default_label=None,
mixup_epoch=-1,
cutmix_epoch=-1,
mosaic_epoch=-1,
**kwargs): **kwargs):
super(DetDataset, self).__init__() super(DetDataset, self).__init__()
self.dataset_dir = dataset_dir if dataset_dir is not None else '' self.dataset_dir = dataset_dir if dataset_dir is not None else ''
...@@ -44,9 +41,6 @@ class DetDataset(Dataset): ...@@ -44,9 +41,6 @@ class DetDataset(Dataset):
self.sample_num = sample_num self.sample_num = sample_num
self.use_default_label = use_default_label self.use_default_label = use_default_label
self.epoch = 0 self.epoch = 0
self.mixup_epoch = mixup_epoch
self.cutmix_epoch = cutmix_epoch
self.mosaic_epoch = mosaic_epoch
def __len__(self, ): def __len__(self, ):
return len(self.roidbs) return len(self.roidbs)
...@@ -77,6 +71,11 @@ class DetDataset(Dataset): ...@@ -77,6 +71,11 @@ class DetDataset(Dataset):
out[k] = roidb[k] out[k] = roidb[k]
return out.values() return out.values()
def set_kwargs(self, **kwargs):
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
def set_out(self, sample_transform, fields): def set_out(self, sample_transform, fields):
self.transform = sample_transform self.transform = sample_transform
self.fields = fields self.fields = fields
......
...@@ -4,7 +4,7 @@ import paddle.nn.functional as F ...@@ -4,7 +4,7 @@ import paddle.nn.functional as F
from paddle import ParamAttr from paddle import ParamAttr
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ppdet.modeling.ops import BatchNorm from ppdet.modeling.ops import batch_norm
__all__ = ['DarkNet', 'ConvBNLayer'] __all__ = ['DarkNet', 'ConvBNLayer']
...@@ -31,7 +31,7 @@ class ConvBNLayer(nn.Layer): ...@@ -31,7 +31,7 @@ class ConvBNLayer(nn.Layer):
groups=groups, groups=groups,
weight_attr=ParamAttr(name=name + '.conv.weights'), weight_attr=ParamAttr(name=name + '.conv.weights'),
bias_attr=False) bias_attr=False)
self.batch_norm = BatchNorm(ch_out, norm_type=norm_type, name=name) self.batch_norm = batch_norm(ch_out, norm_type=norm_type, name=name)
self.act = act self.act = act
def forward(self, inputs): def forward(self, inputs):
......
...@@ -29,40 +29,26 @@ import numpy as np ...@@ -29,40 +29,26 @@ import numpy as np
from functools import reduce from functools import reduce
__all__ = [ __all__ = [
'roi_pool', 'roi_pool', 'roi_align', 'prior_box', 'anchor_generator',
'roi_align', 'generate_proposals', 'iou_similarity', 'box_coder', 'yolo_box',
'prior_box', 'multiclass_nms', 'distribute_fpn_proposals', 'collect_fpn_proposals',
'anchor_generator', 'matrix_nms', 'batch_norm'
'generate_proposals',
'iou_similarity',
'box_coder',
'yolo_box',
'multiclass_nms',
'distribute_fpn_proposals',
'collect_fpn_proposals',
'matrix_nms',
'BatchNorm',
] ]
class BatchNorm(nn.Layer): def batch_norm(ch, norm_type='bn', name=None):
def __init__(self, ch, norm_type='bn', name=None): bn_name = name + '.bn'
super(BatchNorm, self).__init__() if norm_type == 'sync_bn':
bn_name = name + '.bn' batch_norm = nn.SyncBatchNorm
if norm_type == 'sync_bn': else:
batch_norm = nn.SyncBatchNorm batch_norm = nn.BatchNorm2D
else:
batch_norm = nn.BatchNorm2D return batch_norm(
ch,
self.batch_norm = batch_norm( weight_attr=ParamAttr(
ch, name=bn_name + '.scale', regularizer=L2Decay(0.)),
weight_attr=ParamAttr( bias_attr=ParamAttr(
name=bn_name + '.scale', regularizer=L2Decay(0.)), name=bn_name + '.offset', regularizer=L2Decay(0.)))
bias_attr=ParamAttr(
name=bn_name + '.offset', regularizer=L2Decay(0.)))
def forward(self, x):
return self.batch_norm(x)
def roi_pool(input, def roi_pool(input,
......
...@@ -156,7 +156,7 @@ def run(FLAGS, cfg, place): ...@@ -156,7 +156,7 @@ def run(FLAGS, cfg, place):
start_epoch = optimizer.state_dict()['LR_Scheduler']['last_epoch'] start_epoch = optimizer.state_dict()['LR_Scheduler']['last_epoch']
for epoch_id in range(int(cfg.epoch)): for epoch_id in range(int(cfg.epoch)):
cur_eid = epoch_id + start_epoch cur_eid = epoch_id + start_epoch
train_loader.dataset.epoch = epoch_id train_loader.dataset.epoch = cur_eid
for iter_id, data in enumerate(train_loader): for iter_id, data in enumerate(train_loader):
start_time = end_time start_time = end_time
end_time = time.time() end_time = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册