提交 97a365e5 编写于 作者: D dengkaipeng

add yolov3

上级 4d22fee0
......@@ -1125,19 +1125,19 @@ class Model(fluid.dygraph.Layer):
if not isinstance(test_loader, Iterable):
loader = test_loader()
outputs = None
outputs = []
for data in tqdm.tqdm(loader):
if not fluid.in_dygraph_mode():
data = data[0]
outs = self.test(*data)
assert len(data) == len(self._inputs) + len(self._labels), \
"data fileds number mismatch"
inputs_data = data[:len(self._inputs)]
if outputs is None:
outputs = outs
else:
outputs = [
np.vstack([x, outs[i]]) for i, x in enumerate(outputs)
]
outputs.append(self.test(inputs_data))
# sample list to batched data
outputs = list(zip(*outputs))
self._test_dataloader = None
if test_loader is not None and self._adapter._nranks > 1 \
......@@ -1180,11 +1180,16 @@ class Model(fluid.dygraph.Layer):
else:
batch_size = data[0].shape[0]
assert len(data) == len(self._inputs) + len(self._labels), \
"data fileds number mismatch"
inputs_data = data[:len(self._inputs)]
labels_data = data[len(self._inputs):]
callbacks.on_batch_begin(mode, step, logs)
if mode == 'train':
outs = self.train(*data)
outs = self.train(inputs_data, labels_data)
else:
outs = self.eval(*data)
outs = self.eval(inputs_data, labels_data)
# losses
loss = outs[0] if self._metrics else outs
......
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import cv2
import numpy as np
from pycocotools.coco import COCO
from paddle.fluid.io import Dataset
import logging
logger = logging.getLogger(__name__)
__all__ = ['COCODataset']
class COCODataset(Dataset):
"""
Load dataset with MS-COCO format.
Args:
dataset_dir (str): root directory for dataset.
image_dir (str): directory for images.
anno_path (str): voc annotation file path.
sample_num (int): number of samples to load, -1 means all.
use_default_label (bool): whether use the default mapping of
label to integer index. Default True.
with_background (bool): whether load background as a class,
default True.
transform (callable): callable transform to perform on samples,
default None.
mixup (bool): whether return image mixup samples, default False.
alpha (float): alpha factor of beta distribution to generate
mixup score, used only when mixup is True, default 1.5
beta (float): beta factor of beta distribution to generate
mixup score, used only when mixup is True, default 1.5
"""
def __init__(self,
dataset_dir='',
image_dir='',
anno_path='',
sample_num=-1,
with_background=True,
transform=None,
mixup=False,
alpha=1.5,
beta=1.5):
# roidbs is list of dict whose structure is:
# {
# 'im_file': im_fname, # image file name
# 'im_id': im_id, # image id
# 'h': im_h, # height of image
# 'w': im_w, # width
# 'is_crowd': is_crowd,
# 'gt_class': gt_class,
# 'gt_bbox': gt_bbox,
# 'gt_score': gt_score,
# 'difficult': difficult
# }
self._anno_path = os.path.join(dataset_dir, anno_path)
self._image_dir = os.path.join(dataset_dir, image_dir)
assert os.path.exists(self._anno_path), \
"anno_path {} not exists".format(anno_path)
assert os.path.exists(self._image_dir), \
"image_dir {} not exists".format(image_dir)
self._sample_num = sample_num
self._with_background = with_background
self._transform = transform
self._mixup = mixup
self._alpha = alpha
self._beta = beta
# load in dataset roidbs
self._load_roidb_and_cname2cid()
def _load_roidb_and_cname2cid(self):
assert self._anno_path.endswith('.json'), \
'invalid coco annotation file: ' + anno_path
coco = COCO(self._anno_path)
img_ids = coco.getImgIds()
cat_ids = coco.getCatIds()
records = []
ct = 0
# when with_background = True, mapping category to classid, like:
# background:0, first_class:1, second_class:2, ...
catid2clsid = dict({
catid: i + int(self._with_background)
for i, catid in enumerate(cat_ids)
})
cname2cid = dict({
coco.loadCats(catid)[0]['name']: clsid
for catid, clsid in catid2clsid.items()
})
for img_id in img_ids:
img_anno = coco.loadImgs(img_id)[0]
im_fname = img_anno['file_name']
im_w = float(img_anno['width'])
im_h = float(img_anno['height'])
ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
instances = coco.loadAnns(ins_anno_ids)
bboxes = []
for inst in instances:
x, y, box_w, box_h = inst['bbox']
x1 = max(0, x)
y1 = max(0, y)
x2 = min(im_w - 1, x1 + max(0, box_w - 1))
y2 = min(im_h - 1, y1 + max(0, box_h - 1))
if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
inst['clean_bbox'] = [x1, y1, x2, y2]
bboxes.append(inst)
else:
logger.warn(
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
img_id, float(inst['area']), x1, y1, x2, y2))
num_bbox = len(bboxes)
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
gt_score = np.ones((num_bbox, 1), dtype=np.float32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = [None] * num_bbox
for i, box in enumerate(bboxes):
catid = box['category_id']
gt_class[i][0] = catid2clsid[catid]
gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd']
if 'segmentation' in box:
gt_poly[i] = box['segmentation']
im_fname = os.path.join(self._image_dir,
im_fname) if self._image_dir else im_fname
coco_rec = {
'im_file': im_fname,
'im_id': np.array([img_id]),
'h': im_h,
'w': im_w,
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_score': gt_score,
'gt_poly': gt_poly,
}
records.append(coco_rec)
ct += 1
if self._sample_num > 0 and ct >= self._sample_num:
break
assert len(records) > 0, 'not found any coco record in %s' % (self._anno_path)
logger.info('{} samples in file {}'.format(ct, self._anno_path))
self._roidbs, self._cname2cid = records, cname2cid
@property
def num_classes(self):
return len(self._cname2cid)
def __len__(self):
return len(self._roidbs)
def _getitem_by_index(self, idx):
roidb = self._roidbs[idx]
with open(roidb['im_file'], 'rb') as f:
data = np.frombuffer(f.read(), dtype='uint8')
im = cv2.imdecode(data, 1)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_info = np.array([roidb['im_id'][0], roidb['h'], roidb['w']], dtype='int32')
gt_bbox = roidb['gt_bbox']
gt_class = roidb['gt_class']
gt_score = roidb['gt_score']
return im_info, im, gt_bbox, gt_class, gt_score
def __getitem__(self, idx):
im_info, im, gt_bbox, gt_class, gt_score = self._getitem_by_index(idx)
if self._mixup:
mixup_idx = idx + np.random.randint(1, self.__len__())
mixup_idx %= self.__len__()
_, mixup_im, mixup_bbox, mixup_class, _ = \
self._getitem_by_index(mixup_idx)
im, gt_bbox, gt_class, gt_score = \
self._mixup_image(im, gt_bbox, gt_class, mixup_im,
mixup_bbox, mixup_class)
if self._transform:
im_info, im, gt_bbox, gt_class, gt_score = \
self._transform(im_info, im, gt_bbox, gt_class, gt_score)
return [im_info, im, gt_bbox, gt_class, gt_score]
def _mixup_image(self, img1, bbox1, class1, img2, bbox2, class2):
factor = np.random.beta(self._alpha, self._beta)
factor = max(0.0, min(1.0, factor))
if factor >= 1.0:
return img1, bbox1, class1, np.ones_like(class1, dtype="float32")
if factor <= 0.0:
return img2, bbox2, class2, np.ones_like(class2, dtype="float32")
h = max(img1.shape[0], img2.shape[0])
w = max(img1.shape[1], img2.shape[1])
img = np.zeros((h, w, img1.shape[2]), 'float32')
img[:img1.shape[0], :img1.shape[1], :] = \
img1.astype('float32') * factor
img[:img2.shape[0], :img2.shape[1], :] += \
img2.astype('float32') * (1.0 - factor)
gt_bbox = np.concatenate((bbox1, bbox2), axis=0)
gt_class = np.concatenate((class1, class2), axis=0)
score1 = np.ones_like(class1, dtype="float32") * factor
score2 = np.ones_like(class2, dtype="float32") * (1.0 - factor)
gt_score = np.concatenate((score1, score2), axis=0)
return img, gt_bbox, gt_class, gt_score
@property
def mixup(self):
return self._mixup
@mixup.setter
def mixup(self, value):
if not isinstance(value, bool):
raise ValueError("mixup should be a boolean number")
logger.info("{} set mixup to {}".format(self, value))
self._mixup = value
def pascalvoc_label(with_background=True):
labels_map = {
'aeroplane': 1,
'bicycle': 2,
'bird': 3,
'boat': 4,
'bottle': 5,
'bus': 6,
'car': 7,
'cat': 8,
'chair': 9,
'cow': 10,
'diningtable': 11,
'dog': 12,
'horse': 13,
'motorbike': 14,
'person': 15,
'pottedplant': 16,
'sheep': 17,
'sofa': 18,
'train': 19,
'tvmonitor': 20
}
if not with_background:
labels_map = {k: v - 1 for k, v in labels_map.items()}
return labels_map
......@@ -17,8 +17,6 @@ import json
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
from metrics import Metric
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
......@@ -31,7 +29,7 @@ OUTFILE = './bbox.json'
# considered to change to a callback later
class COCOMetric(Metric):
class COCOMetric():
"""
Metrci for MS-COCO dataset, only support update with batch
size as 1.
......@@ -43,7 +41,6 @@ class COCOMetric(Metric):
"""
def __init__(self, anno_path, with_background=True, **kwargs):
super(COCOMetric, self).__init__(**kwargs)
self.anno_path = anno_path
self.with_background = with_background
self.bbox_results = []
......@@ -54,15 +51,14 @@ class COCOMetric(Metric):
{i + int(with_background): catid
for i, catid in enumerate(cat_ids)})
def update(self, preds, *args, **kwargs):
im_ids, bboxes = preds
assert im_ids.shape[0] == 1, \
def update(self, img_id, bboxes):
assert img_id.shape[0] == 1, \
"COCOMetric can only update with batch size = 1"
if bboxes.shape[1] != 6:
# no bbox detected in this batch
return
im_id = int(im_ids)
img_id = int(img_id)
for i in range(bboxes.shape[0]):
dt = bboxes[i, :]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
......@@ -72,7 +68,7 @@ class COCOMetric(Metric):
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
'image_id': im_id,
'image_id': img_id,
'category_id': catid,
'bbox': bbox,
'score': score
......@@ -95,7 +91,7 @@ class COCOMetric(Metric):
# flush coco evaluation result
sys.stdout.flush()
self.result = map_stats[0]
return self.result
return [self.result]
def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
assert coco_gt != None or anno_file != None
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm
from paddle.fluid.dygraph.base import to_variable
__all__ = ['DarkNet53', 'ConvBNLayer']
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=1,
groups=1,
padding=0,
act="leaky"):
super(ConvBNLayer, self).__init__()
self.conv = Conv2D(
num_channels=ch_in,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02)),
bias_attr=False,
act=None)
self.batch_norm = BatchNorm(
num_channels=ch_out,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02),
regularizer=L2Decay(0.)),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.)))
self.act = act
def forward(self, inputs):
out = self.conv(inputs)
out = self.batch_norm(out)
if self.act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
class DownSample(fluid.dygraph.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=2,
padding=1):
super(DownSample, self).__init__()
self.conv_bn_layer = ConvBNLayer(
ch_in=ch_in,
ch_out=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding)
self.ch_out = ch_out
def forward(self, inputs):
out = self.conv_bn_layer(inputs)
return out
class BasicBlock(fluid.dygraph.Layer):
def __init__(self, ch_in, ch_out):
super(BasicBlock, self).__init__()
self.conv1 = ConvBNLayer(
ch_in=ch_in,
ch_out=ch_out,
filter_size=1,
stride=1,
padding=0)
self.conv2 = ConvBNLayer(
ch_in=ch_out,
ch_out=ch_out*2,
filter_size=3,
stride=1,
padding=1)
def forward(self, inputs):
conv1 = self.conv1(inputs)
conv2 = self.conv2(conv1)
out = fluid.layers.elementwise_add(x=inputs, y=conv2, act=None)
return out
class LayerWarp(fluid.dygraph.Layer):
def __init__(self, ch_in, ch_out, count):
super(LayerWarp,self).__init__()
self.basicblock0 = BasicBlock(ch_in, ch_out)
self.res_out_list = []
for i in range(1,count):
res_out = self.add_sublayer("basic_block_%d" % (i),
BasicBlock(
ch_out*2,
ch_out))
self.res_out_list.append(res_out)
self.ch_out = ch_out
def forward(self,inputs):
y = self.basicblock0(inputs)
for basic_block_i in self.res_out_list:
y = basic_block_i(y)
return y
DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}
class DarkNet53(fluid.dygraph.Layer):
def __init__(self, ch_in=3):
super(DarkNet53, self).__init__()
self.stages = DarkNet_cfg[53]
self.stages = self.stages[0:5]
self.conv0 = ConvBNLayer(
ch_in=ch_in,
ch_out=32,
filter_size=3,
stride=1,
padding=1)
self.downsample0 = DownSample(
ch_in=32,
ch_out=32 * 2)
self.darknet53_conv_block_list = []
self.downsample_list = []
ch_in = [64,128,256,512,1024]
for i, stage in enumerate(self.stages):
conv_block = self.add_sublayer(
"stage_%d" % (i),
LayerWarp(
int(ch_in[i]),
32*(2**i),
stage))
self.darknet53_conv_block_list.append(conv_block)
for i in range(len(self.stages) - 1):
downsample = self.add_sublayer(
"stage_%d_downsample" % i,
DownSample(
ch_in = 32*(2**(i+1)),
ch_out = 32*(2**(i+2))))
self.downsample_list.append(downsample)
def forward(self,inputs):
out = self.conv0(inputs)
out = self.downsample0(out)
blocks = []
for i, conv_block_i in enumerate(self.darknet53_conv_block_list):
out = conv_block_i(out)
blocks.append(out)
if i < len(self.stages) - 1:
out = self.downsample_list[i](out)
return blocks[-1:-4:-1]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from model import Model, Loss
from .darknet import DarkNet53, ConvBNLayer
__all__ = ['YoloLoss', 'YOLOv3']
class YoloDetectionBlock(fluid.dygraph.Layer):
def __init__(self, ch_in, channel):
super(YoloDetectionBlock, self).__init__()
assert channel % 2 == 0, \
"channel {} cannot be divided by 2".format(channel)
self.conv0 = ConvBNLayer(
ch_in=ch_in,
ch_out=channel,
filter_size=1,
stride=1,
padding=0)
self.conv1 = ConvBNLayer(
ch_in=channel,
ch_out=channel*2,
filter_size=3,
stride=1,
padding=1)
self.conv2 = ConvBNLayer(
ch_in=channel*2,
ch_out=channel,
filter_size=1,
stride=1,
padding=0)
self.conv3 = ConvBNLayer(
ch_in=channel,
ch_out=channel*2,
filter_size=3,
stride=1,
padding=1)
self.route = ConvBNLayer(
ch_in=channel*2,
ch_out=channel,
filter_size=1,
stride=1,
padding=0)
self.tip = ConvBNLayer(
ch_in=channel,
ch_out=channel*2,
filter_size=3,
stride=1,
padding=1)
def forward(self, inputs):
out = self.conv0(inputs)
out = self.conv1(out)
out = self.conv2(out)
out = self.conv3(out)
route = self.route(out)
tip = self.tip(route)
return route, tip
class YOLOv3(Model):
def __init__(self, num_classes=80, model_mode='train'):
super(YOLOv3, self).__init__()
self.num_classes = num_classes
assert str.lower(model_mode) in ['train', 'eval'], \
"model_mode should be 'train' or 'val', but got " \
"{}".format(model_mode)
self.model_mode = str.lower(model_mode)
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
59, 119, 116, 90, 156, 198, 373, 326]
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
self.valid_thresh = 0.005
self.nms_thresh = 0.45
self.nms_topk = 400
self.nms_posk = 100
self.draw_thresh = 0.5
self.block = DarkNet53()
self.block_outputs = []
self.yolo_blocks = []
self.route_blocks = []
for idx, num_chan in enumerate([1024, 768, 384]):
yolo_block = self.add_sublayer(
"yolo_detecton_block_{}".format(idx),
YoloDetectionBlock(num_chan, 512 // (2**idx)))
self.yolo_blocks.append(yolo_block)
num_filters = len(self.anchor_masks[idx]) * (self.num_classes + 5)
block_out = self.add_sublayer(
"block_out_{}".format(idx),
Conv2D(num_channels=1024 // (2**idx),
num_filters=num_filters,
filter_size=1,
act=None,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02)),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.))))
self.block_outputs.append(block_out)
if idx < 2:
route = self.add_sublayer(
"route2_{}".format(idx),
ConvBNLayer(ch_in=512 // (2**idx),
ch_out=256 // (2**idx),
filter_size=1,
act='leaky_relu'))
self.route_blocks.append(route)
def forward(self, img_info, inputs):
outputs = []
boxes = []
scores = []
downsample = 32
feats = self.block(inputs)
route = None
for idx, feat in enumerate(feats):
if idx > 0:
feat = fluid.layers.concat(input=[route, feat], axis=1)
route, tip = self.yolo_blocks[idx](feat)
block_out = self.block_outputs[idx](tip)
outputs.append(block_out)
if idx < 2:
route = self.route_blocks[idx](route)
route = fluid.layers.resize_nearest(route, scale=2)
if self.model_mode == 'eval':
anchor_mask = self.anchor_masks[idx]
mask_anchors = []
for m in anchor_mask:
mask_anchors.append(self.anchors[2 * m])
mask_anchors.append(self.anchors[2 * m + 1])
img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3])
img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1])
b, s = fluid.layers.yolo_box(
x=block_out,
img_size=img_shape,
anchors=mask_anchors,
class_num=self.num_classes,
conf_thresh=self.valid_thresh,
downsample_ratio=downsample)
boxes.append(b)
scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
downsample //= 2
if self.model_mode == 'train':
return outputs
return outputs + [img_id[0, :], fluid.layers.multiclass_nms(
bboxes=fluid.layers.concat(boxes, axis=1),
scores=fluid.layers.concat(scores, axis=2),
score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk,
keep_top_k=self.nms_posk,
nms_threshold=self.nms_thresh,
background_label=-1)
]
class YoloLoss(Loss):
def __init__(self, num_classes=80, num_max_boxes=50):
super(YoloLoss, self).__init__()
self.num_classes = num_classes
self.num_max_boxes = num_max_boxes
self.ignore_thresh = 0.7
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
59, 119, 116, 90, 156, 198, 373, 326]
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
def forward(self, outputs, labels):
downsample = 32
gt_box, gt_label, gt_score = labels
losses = []
for idx, out in enumerate(outputs):
if idx == 3: break # debug
anchor_mask = self.anchor_masks[idx]
loss = fluid.layers.yolov3_loss(
x=out,
gt_box=gt_box,
gt_label=gt_label,
gt_score=gt_score,
anchor_mask=anchor_mask,
downsample_ratio=downsample,
anchors=self.anchors,
class_num=self.num_classes,
ignore_thresh=self.ignore_thresh,
use_label_smooth=True)
loss = fluid.layers.reduce_mean(loss)
losses.append(loss)
downsample //= 2
return losses
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册