未验证 提交 ea8e8ebd 编写于 作者: W wangxinxin08 提交者: GitHub

refine code to avoid some problem (#1772)

上级 aa16d88a
......@@ -6,7 +6,6 @@ TrainDataset:
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
mixup_epoch: 250
EvalDataset:
!COCODataSet
......
......@@ -21,6 +21,7 @@ TrainReader:
batch_size: 8
shuffle: true
drop_last: true
mixup_epoch: 250
EvalReader:
......
......@@ -30,6 +30,7 @@ class Compose(object):
if hasattr(op_cls, 'num_classes'):
op_cls.num_classes = num_classes
# TODO: should be refined in the future
if op_cls in [
transform.Gt2YoloTargetOp, transform.Gt2YoloTarget
]:
......@@ -89,7 +90,8 @@ class BaseDataLoader(object):
drop_last=False,
drop_empty=True,
num_classes=81,
with_background=True):
with_background=True,
**kwargs):
# out fields
self._fields = inputs_def['fields'] if inputs_def else None
# sample transform
......@@ -107,6 +109,7 @@ class BaseDataLoader(object):
self.shuffle = shuffle
self.drop_last = drop_last
self.with_background = with_background
self.kwargs = kwargs
def __call__(self,
dataset,
......@@ -120,6 +123,8 @@ class BaseDataLoader(object):
# get data
self._dataset.set_out(self._sample_transforms,
copy.deepcopy(self._fields))
# set kwargs
self._dataset.set_kwargs(**self.kwargs)
# batch sampler
if batch_sampler is None:
self._batch_sampler = DistributedBatchSampler(
......@@ -154,10 +159,12 @@ class TrainReader(BaseDataLoader):
drop_last=True,
drop_empty=True,
num_classes=81,
with_background=True):
super(TrainReader, self).__init__(
inputs_def, sample_transforms, batch_transforms, batch_size,
shuffle, drop_last, drop_empty, num_classes, with_background)
with_background=True,
**kwargs):
super(TrainReader, self).__init__(inputs_def, sample_transforms,
batch_transforms, batch_size, shuffle,
drop_last, drop_empty, num_classes,
with_background, **kwargs)
@register
......@@ -171,10 +178,12 @@ class EvalReader(BaseDataLoader):
drop_last=True,
drop_empty=True,
num_classes=81,
with_background=True):
super(EvalReader, self).__init__(
inputs_def, sample_transforms, batch_transforms, batch_size,
shuffle, drop_last, drop_empty, num_classes, with_background)
with_background=True,
**kwargs):
super(EvalReader, self).__init__(inputs_def, sample_transforms,
batch_transforms, batch_size, shuffle,
drop_last, drop_empty, num_classes,
with_background, **kwargs)
@register
......@@ -188,7 +197,9 @@ class TestReader(BaseDataLoader):
drop_last=False,
drop_empty=True,
num_classes=81,
with_background=True):
super(TestReader, self).__init__(
inputs_def, sample_transforms, batch_transforms, batch_size,
shuffle, drop_last, drop_empty, num_classes, with_background)
with_background=True,
**kwargs):
super(TestReader, self).__init__(inputs_def, sample_transforms,
batch_transforms, batch_size, shuffle,
drop_last, drop_empty, num_classes,
with_background, **kwargs)
......@@ -28,18 +28,9 @@ class COCODataSet(DetDataset):
dataset_dir=None,
image_dir=None,
anno_path=None,
mixup_epoch=-1,
cutmix_epoch=-1,
mosaic_epoch=-1,
sample_num=-1):
super(COCODataSet, self).__init__(
dataset_dir,
image_dir,
anno_path,
sample_num,
mixup_epoch=mixup_epoch,
cutmix_epoch=cutmix_epoch,
mosaic_epoch=mosaic_epoch)
super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
sample_num)
self.load_image_only = False
self.load_semantic = False
......
......@@ -33,9 +33,6 @@ class DetDataset(Dataset):
anno_path=None,
sample_num=-1,
use_default_label=None,
mixup_epoch=-1,
cutmix_epoch=-1,
mosaic_epoch=-1,
**kwargs):
super(DetDataset, self).__init__()
self.dataset_dir = dataset_dir if dataset_dir is not None else ''
......@@ -44,9 +41,6 @@ class DetDataset(Dataset):
self.sample_num = sample_num
self.use_default_label = use_default_label
self.epoch = 0
self.mixup_epoch = mixup_epoch
self.cutmix_epoch = cutmix_epoch
self.mosaic_epoch = mosaic_epoch
def __len__(self, ):
return len(self.roidbs)
......@@ -77,6 +71,11 @@ class DetDataset(Dataset):
out[k] = roidb[k]
return out.values()
def set_kwargs(self, **kwargs):
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
def set_out(self, sample_transform, fields):
self.transform = sample_transform
self.fields = fields
......
......@@ -4,7 +4,7 @@ import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from ppdet.modeling.ops import BatchNorm
from ppdet.modeling.ops import batch_norm
__all__ = ['DarkNet', 'ConvBNLayer']
......@@ -31,7 +31,7 @@ class ConvBNLayer(nn.Layer):
groups=groups,
weight_attr=ParamAttr(name=name + '.conv.weights'),
bias_attr=False)
self.batch_norm = BatchNorm(ch_out, norm_type=norm_type, name=name)
self.batch_norm = batch_norm(ch_out, norm_type=norm_type, name=name)
self.act = act
def forward(self, inputs):
......
......@@ -29,41 +29,27 @@ import numpy as np
from functools import reduce
__all__ = [
'roi_pool',
'roi_align',
'prior_box',
'anchor_generator',
'generate_proposals',
'iou_similarity',
'box_coder',
'yolo_box',
'multiclass_nms',
'distribute_fpn_proposals',
'collect_fpn_proposals',
'matrix_nms',
'BatchNorm',
'roi_pool', 'roi_align', 'prior_box', 'anchor_generator',
'generate_proposals', 'iou_similarity', 'box_coder', 'yolo_box',
'multiclass_nms', 'distribute_fpn_proposals', 'collect_fpn_proposals',
'matrix_nms', 'batch_norm'
]
class BatchNorm(nn.Layer):
def __init__(self, ch, norm_type='bn', name=None):
super(BatchNorm, self).__init__()
def batch_norm(ch, norm_type='bn', name=None):
bn_name = name + '.bn'
if norm_type == 'sync_bn':
batch_norm = nn.SyncBatchNorm
else:
batch_norm = nn.BatchNorm2D
self.batch_norm = batch_norm(
return batch_norm(
ch,
weight_attr=ParamAttr(
name=bn_name + '.scale', regularizer=L2Decay(0.)),
bias_attr=ParamAttr(
name=bn_name + '.offset', regularizer=L2Decay(0.)))
def forward(self, x):
return self.batch_norm(x)
def roi_pool(input,
rois,
......
......@@ -156,7 +156,7 @@ def run(FLAGS, cfg, place):
start_epoch = optimizer.state_dict()['LR_Scheduler']['last_epoch']
for epoch_id in range(int(cfg.epoch)):
cur_eid = epoch_id + start_epoch
train_loader.dataset.epoch = epoch_id
train_loader.dataset.epoch = cur_eid
for iter_id, data in enumerate(train_loader):
start_time = end_time
end_time = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册