提交 5bb37edc 编写于 作者: J jerrywgz 提交者: qingqing01

Add Fasterrcnn model (#1211)

* Object detection model: faster-rcnn  by @jerrywgz @sefira @qingqing01
上级 5d60fb41
output/
*.swp
*.log
log*
output*
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# Detectron
# Copyright (c) 2017-present, Facebook, Inc.
# Licensed under the Apache License, Version 2.0;
# Written by Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
def xywh_to_xyxy(xywh):
"""Convert [x1 y1 w h] box format to [x1 y1 x2 y2] format."""
if isinstance(xywh, (list, tuple)):
# Single box given as a list of coordinates
assert len(xywh) == 4
x1, y1 = xywh[0], xywh[1]
x2 = x1 + np.maximum(0., xywh[2] - 1.)
y2 = y1 + np.maximum(0., xywh[3] - 1.)
return (x1, y1, x2, y2)
elif isinstance(xywh, np.ndarray):
# Multiple boxes given as a 2D ndarray
return np.hstack(
(xywh[:, 0:2], xywh[:, 0:2] + np.maximum(0, xywh[:, 2:4] - 1)))
else:
raise TypeError('Argument xywh must be a list, tuple, or numpy array.')
def xyxy_to_xywh(xyxy):
"""Convert [x1 y1 x2 y2] box format to [x1 y1 w h] format."""
if isinstance(xyxy, (list, tuple)):
# Single box given as a list of coordinates
assert len(xyxy) == 4
x1, y1 = xyxy[0], xyxy[1]
w = xyxy[2] - x1 + 1
h = xyxy[3] - y1 + 1
return (x1, y1, w, h)
elif isinstance(xyxy, np.ndarray):
# Multiple boxes given as a 2D ndarray
return np.hstack((xyxy[:, 0:2], xyxy[:, 2:4] - xyxy[:, 0:2] + 1))
else:
raise TypeError('Argument xyxy must be a list, tuple, or numpy array.')
def clip_xyxy_to_image(x1, y1, x2, y2, height, width):
"""Clip coordinates to an image with the given height and width."""
x1 = np.minimum(width - 1., np.maximum(0., x1))
y1 = np.minimum(height - 1., np.maximum(0., y1))
x2 = np.minimum(width - 1., np.maximum(0., x2))
y2 = np.minimum(height - 1., np.maximum(0., y2))
return x1, y1, x2, y2
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# Detectron
# Copyright (c) 2017-present, Facebook, Inc.
# Licensed under the Apache License, Version 2.0;
# Written by Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import cv2
import numpy as np
def get_image_blob(roidb, settings):
"""Builds an input blob from the images in the roidb at the specified
scales.
"""
scale_ind = np.random.randint(0, high=len(settings.scales))
im = cv2.imread(roidb['image'])
assert im is not None, \
'Failed to read image \'{}\''.format(roidb['image'])
if roidb['flipped']:
im = im[:, ::-1, :]
#print(im[10:, 10:, :])
target_size = settings.scales[scale_ind]
im, im_scale = prep_im_for_blob(im, settings.mean_value, target_size,
settings.max_size)
return im, im_scale
def prep_im_for_blob(im, pixel_means, target_size, max_size):
"""Prepare an image for use as a network input blob. Specially:
- Subtract per-channel pixel mean
- Convert to float32
- Rescale to each of the specified target size (capped at max_size)
Returns a list of transformed images, one for each target size. Also returns
the scale factors that were used to compute each returned image.
"""
im = im.astype(np.float32, copy=False)
im -= pixel_means
#print(im[10:, 10:, :])
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
im_scale = float(target_size) / float(im_size_min)
# Prevent the biggest axis from being more than max_size
if np.round(im_scale * im_size_max) > max_size:
im_scale = float(max_size) / float(im_size_max)
im = cv2.resize(
im,
None,
None,
fx=im_scale,
fy=im_scale,
interpolation=cv2.INTER_LINEAR)
im_height, im_width, channel = im.shape
padding_im = np.zeros((max_size, max_size, im_shape[2]), dtype=np.float32)
padding_im[:im_height, :im_width, :] = im
#print(padding_im[10:, 10:, :])
channel_swap = (2, 0, 1) #(batch, channel, height, width)
#im = im.transpose(channel_swap)
padding_im = padding_im.transpose(channel_swap)
#print(padding_im[10:, 10:, :])
return padding_im, im_scale
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"
# Download the data.
echo "Downloading..."
wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
# Extract the data.
echo "Extracting..."
unzip train2014.zip
unzip val2014.zip
unzip train2017.zip
unzip val2017.zip
unzip annotations_trainval2014.zip
unzip annotations_trainval2017.zip
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle.fluid.regularizer import L2Decay
import paddle.fluid as fluid
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
name=None):
conv1 = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=ParamAttr(name=name + "_biases"),
name=name + '.conv2d.output.1')
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv1,
act=act,
name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
is_test=True)
def conv_affine_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + '.conv2d.output.1')
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
scale = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=ParamAttr(
name=bn_name + '_scale', learning_rate=0.),
default_initializer=Constant(1.))
scale.stop_gradient = True
bias = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=ParamAttr(
bn_name + '_offset', learning_rate=0.),
default_initializer=Constant(0.))
bias.stop_gradient = True
elt_mul = fluid.layers.elementwise_mul(x=conv, y=scale, axis=1)
out = fluid.layers.elementwise_add(x=elt_mul, y=bias, axis=1)
if act == 'relu':
out = fluid.layers.relu(x=out)
return out
def shortcut(input, ch_out, stride, name):
ch_in = input.shape[1] # if args.data_format == 'NCHW' else input.shape[-1]
if ch_in != ch_out:
return conv_affine_layer(input, ch_out, 1, stride, 0, None, name=name)
else:
return input
def basicblock(input, ch_out, stride, name):
short = shortcut(input, ch_out, stride, name=name)
conv1 = conv_affine_layer(input, ch_out, 3, stride, 1, name=name)
conv2 = conv_affine_layer(conv1, ch_out, 3, 1, 1, act=None, name=name)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu', name=name)
def bottleneck(input, ch_out, stride, name):
short = shortcut(input, ch_out * 4, stride, name=name + "_branch1")
conv1 = conv_affine_layer(
input, ch_out, 1, stride, 0, name=name + "_branch2a")
conv2 = conv_affine_layer(conv1, ch_out, 3, 1, 1, name=name + "_branch2b")
conv3 = conv_affine_layer(
conv2, ch_out * 4, 1, 1, 0, act=None, name=name + "_branch2c")
return fluid.layers.elementwise_add(
x=short, y=conv3, act='relu', name=name + ".add.output.5")
def layer_warp(block_func, input, ch_out, count, stride, name):
res_out = block_func(input, ch_out, stride, name=name + "a")
for i in range(1, count):
res_out = block_func(res_out, ch_out, 1, name=name + chr(ord("a") + i))
return res_out
def FasterRcnn(input, depth, anchor_sizes, variance, aspect_ratios, gt_box,
is_crowd, gt_label, im_info, class_nums, use_random):
cfg = {
18: ([2, 2, 2, 1], basicblock),
34: ([3, 4, 6, 3], basicblock),
50: ([3, 4, 6, 3], bottleneck),
101: ([3, 4, 23, 3], bottleneck),
152: ([3, 8, 36, 3], bottleneck)
}
stages, block_func = cfg[depth]
conv1 = conv_affine_layer(
input, ch_out=64, filter_size=7, stride=2, padding=3, name="conv1")
pool1 = fluid.layers.pool2d(
input=conv1,
pool_type='max',
pool_size=3,
pool_stride=2,
pool_padding=1,
name="pool1.max_pool.output.1")
res2 = layer_warp(block_func, pool1, 64, stages[0], 1, name="res2")
res2.stop_gradient = True
res3 = layer_warp(block_func, res2, 128, stages[1], 2, name="res3")
res4 = layer_warp(block_func, res3, 256, stages[2], 2, name="res4")
#========= RPN ============
# rpn_conv/3*3
rpn_conv = fluid.layers.conv2d(
input=res4,
num_filters=1024,
filter_size=3,
stride=1,
padding=1,
act='relu',
name='conv_rpn',
param_attr=ParamAttr(name="conv_rpn_w"),
bias_attr=ParamAttr(
name="conv_rpn_b", learning_rate=2., regularizer=L2Decay(0.)))
anchor, var = fluid.layers.anchor_generator(
input=rpn_conv,
anchor_sizes=anchor_sizes,
aspect_ratios=aspect_ratios,
variance=variance,
stride=[16.0, 16.0])
num_anchor = anchor.shape[2]
rpn_cls_score = fluid.layers.conv2d(
rpn_conv,
num_filters=num_anchor,
filter_size=1,
stride=1,
padding=0,
act=None,
name='rpn_cls_score',
param_attr=ParamAttr(name="rpn_cls_logits_w"),
bias_attr=ParamAttr(
name="rpn_cls_logits_b", learning_rate=2., regularizer=L2Decay(0.)))
rpn_bbox_pred = fluid.layers.conv2d(
rpn_conv,
num_filters=4 * num_anchor,
filter_size=1,
stride=1,
padding=0,
act=None,
name='rpn_bbox_pred',
param_attr=ParamAttr(name="rpn_bbox_pred_w"),
bias_attr=ParamAttr(
name="rpn_bbox_pred_b", learning_rate=2., regularizer=L2Decay(0.)))
rpn_cls_score_prob = fluid.layers.sigmoid(
rpn_cls_score, name='rpn_cls_score_prob')
rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals(
scores=rpn_cls_score_prob,
bbox_deltas=rpn_bbox_pred,
im_info=im_info,
anchors=anchor,
variances=var,
pre_nms_top_n=12000,
post_nms_top_n=2000,
nms_thresh=0.7,
min_size=0.0,
eta=1.0)
rois, labels_int32, bbox_targets, bbox_inside_weights, \
bbox_outside_weights = fluid.layers.generate_proposal_labels(
rpn_rois=rpn_rois,
gt_classes=gt_label,
is_crowd=is_crowd,
gt_boxes=gt_box,
im_info=im_info,
batch_size_per_im=512,
fg_fraction=0.25,
fg_thresh=0.5,
bg_thresh_hi=0.5,
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=class_nums,
use_random=use_random)
rois.stop_gradient = True
labels_int32.stop_gradient = True
bbox_targets.stop_gradient = True
bbox_inside_weights.stop_gradient = True
bbox_outside_weights.stop_gradient = True
pool5 = fluid.layers.roi_pool(
input=res4,
rois=rois,
pooled_height=14,
pooled_width=14,
spatial_scale=0.0625)
res5_2_sum = layer_warp(block_func, pool5, 512, stages[3], 2, name="res5")
res5_pool = fluid.layers.pool2d(
res5_2_sum, pool_type='avg', pool_size=7, name='res5_pool')
cls_score = fluid.layers.fc(input=res5_pool,
size=class_nums,
act=None,
name='cls_score',
param_attr=ParamAttr(name='cls_score_w'),
bias_attr=ParamAttr(
name='cls_score_b',
learning_rate=2.,
regularizer=L2Decay(0.)))
bbox_pred = fluid.layers.fc(input=res5_pool,
size=4 * class_nums,
act=None,
name='bbox_pred',
param_attr=ParamAttr(name='bbox_pred_w'),
bias_attr=ParamAttr(
name='bbox_pred_b',
learning_rate=2.,
regularizer=L2Decay(0.)))
return rpn_cls_score, rpn_bbox_pred, anchor, var, cls_score,\
bbox_pred, bbox_targets, bbox_inside_weights, \
bbox_outside_weights, rois, labels_int32
def RPNloss(rpn_cls_prob, rpn_bbox_pred, anchor, var, gt_box, is_crowd, im_info,
use_random):
rpn_cls_score_reshape = fluid.layers.transpose(
rpn_cls_prob, perm=[0, 2, 3, 1])
rpn_bbox_pred_reshape = fluid.layers.transpose(
rpn_bbox_pred, perm=[0, 2, 3, 1])
anchor_reshape = fluid.layers.reshape(anchor, shape=(-1, 4))
var_reshape = fluid.layers.reshape(var, shape=(-1, 4))
rpn_cls_score_reshape = fluid.layers.reshape(
x=rpn_cls_score_reshape, shape=(0, -1, 1))
rpn_bbox_pred_reshape = fluid.layers.reshape(
x=rpn_bbox_pred_reshape, shape=(0, -1, 4))
score_pred, loc_pred, score_target, loc_target = fluid.layers.rpn_target_assign(
bbox_pred=rpn_bbox_pred_reshape,
cls_logits=rpn_cls_score_reshape,
anchor_box=anchor_reshape,
anchor_var=var_reshape,
gt_boxes=gt_box,
is_crowd=is_crowd,
im_info=im_info,
rpn_batch_size_per_im=256,
rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3,
use_random=use_random)
score_target = fluid.layers.cast(x=score_target, dtype='float32')
rpn_cls_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=score_pred, label=score_target)
rpn_cls_loss = fluid.layers.reduce_mean(rpn_cls_loss, name='loss_rpn_cls')
rpn_reg_loss = fluid.layers.smooth_l1(x=loc_pred, y=loc_target, sigma=3.0)
rpn_reg_loss = fluid.layers.reduce_sum(rpn_reg_loss, name='loss_rpn_bbox')
score_shape = fluid.layers.shape(score_target)
score_shape = fluid.layers.cast(x=score_shape, dtype='float32')
norm = fluid.layers.reduce_prod(score_shape)
norm.stop_gradient = True
rpn_reg_loss = rpn_reg_loss / norm
return rpn_cls_loss, rpn_reg_loss
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler
from paddle.fluid.layers import control_flow
def exponential_with_warmup_decay(learning_rate, boundaries, values,
warmup_iter, warmup_factor):
global_step = lr_scheduler._decay_step_counter()
lr = fluid.layers.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
warmup_iter_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(warmup_iter), force_cpu=True)
with control_flow.Switch() as switch:
with switch.case(global_step < warmup_iter_var):
alpha = global_step / warmup_iter_var
factor = warmup_factor * (1 - alpha) + alpha
decayed_lr = learning_rate * factor
fluid.layers.assign(decayed_lr, lr)
for i in range(len(boundaries)):
boundary_val = fluid.layers.fill_constant(
shape=[1],
dtype='float32',
value=float(boundaries[i]),
force_cpu=True)
value_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(global_step < boundary_val):
fluid.layers.assign(value_var, lr)
last_value_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(values[len(values) - 1]))
with switch.default():
fluid.layers.assign(last_value_var, lr)
return lr
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.utils.image_util import *
import random
from PIL import Image
from PIL import ImageDraw
import numpy as np
import xml.etree.ElementTree
import os
import time
import copy
import six
from roidbs import JsonDataset
import data_utils
class Settings(object):
def __init__(self, args=None):
for arg, value in sorted(six.iteritems(vars(args))):
setattr(self, arg, value)
if 'coco2014' in args.dataset:
self.class_nums = 81
self.train_file_list = 'annotations/instances_train2014.json'
self.train_data_dir = 'train2014'
self.val_file_list = 'annotations/instances_val2014.json'
self.val_data_dir = 'val2014'
elif 'coco2017' in args.dataset:
self.class_nums = 81
self.train_file_list = 'annotations/instances_train2017.json'
self.train_data_dir = 'train2017'
self.val_file_list = 'annotations/instances_val2017.json'
self.val_data_dir = 'val2017'
else:
raise NotImplementedError('Dataset {} not supported'.format(
self.dataset))
self.mean_value = np.array(self.mean_value)[
np.newaxis, np.newaxis, :].astype('float32')
def coco(settings, mode, shuffle):
if mode == 'train':
settings.train_file_list = os.path.join(settings.data_dir,
settings.train_file_list)
settings.train_data_dir = os.path.join(settings.data_dir,
settings.train_data_dir)
elif mode == 'test':
settings.val_file_list = os.path.join(settings.data_dir,
settings.val_file_list)
settings.val_data_dir = os.path.join(settings.data_dir,
settings.val_data_dir)
json_dataset = JsonDataset(settings, train=(mode == 'train'))
roidbs = json_dataset.get_roidb()
print("{} on {} with {} roidbs".format(mode, settings.dataset, len(roidbs)))
def reader():
if mode == "train" and shuffle:
random.shuffle(roidbs)
im_out, gt_boxes_out, gt_classes_out, is_crowd_out, im_info_out = [],[],[],[],[]
lod = [0]
for roidb in roidbs:
im, im_scales = data_utils.get_image_blob(roidb, settings)
im_height = np.round(roidb['height'] * im_scales)
im_width = np.round(roidb['width'] * im_scales)
im_info = np.array(
[im_height, im_width, im_scales], dtype=np.float32)
gt_boxes = roidb['gt_boxes'].astype('float32')
gt_classes = roidb['gt_classes'].astype('int32')
is_crowd = roidb['is_crowd'].astype('int32')
if gt_boxes.shape[0] == 0:
continue
im_out.append(im)
gt_boxes_out.extend(gt_boxes)
gt_classes_out.extend(gt_classes)
is_crowd_out.extend(is_crowd)
im_info_out.append(im_info)
lod.append(lod[-1] + gt_boxes.shape[0])
if len(im_out) == settings.batch_size:
im_out = np.array(im_out).astype('float32')
gt_boxes_out = np.array(gt_boxes_out).astype('float32')
gt_classes_out = np.array(gt_classes_out).astype('int32')
is_crowd_out = np.array(is_crowd_out).astype('int32')
im_info_out = np.array(im_info_out).astype('float32')
yield im_out, gt_boxes_out, gt_classes_out, is_crowd_out, im_info_out, lod
im_out, gt_boxes_out, gt_classes_out, is_crowd_out, im_info_out = [],[],[],[],[]
lod = [0]
return reader
def train(settings, shuffle=True):
return coco(settings, 'train', shuffle)
def test(settings):
return coco(settings, 'test', False)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# Detectron
# Copyright (c) 2017-present, Facebook, Inc.
# Licensed under the Apache License, Version 2.0;
# Written by Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
import cPickle as pickle
import logging
import numpy as np
import os
import scipy.sparse
import random
import time
import matplotlib
matplotlib.use('Agg')
from pycocotools.coco import COCO
import box_utils
logger = logging.getLogger(__name__)
class JsonDataset(object):
"""A class representing a COCO json dataset."""
def __init__(self, args, train=False):
logger.debug('Creating: {}'.format(args.dataset))
self.name = args.dataset
self.is_train = train
if self.is_train:
data_dir = args.train_data_dir
file_list = args.train_file_list
else:
data_dir = args.val_data_dir
file_list = args.val_file_list
self.image_directory = data_dir
self.COCO = COCO(file_list)
# Set up dataset classes
category_ids = self.COCO.getCatIds()
categories = [c['name'] for c in self.COCO.loadCats(category_ids)]
self.category_to_id_map = dict(zip(categories, category_ids))
self.classes = ['__background__'] + categories
self.num_classes = len(self.classes)
self.json_category_id_to_contiguous_id = {
v: i + 1
for i, v in enumerate(self.COCO.getCatIds())
}
self.contiguous_category_id_to_json_id = {
v: k
for k, v in self.json_category_id_to_contiguous_id.items()
}
def get_roidb(self):
"""Return an roidb corresponding to the json dataset. Optionally:
- include ground truth boxes in the roidb
- add proposals specified in a proposals file
- filter proposals based on a minimum side length
- filter proposals that intersect with crowd regions
"""
image_ids = self.COCO.getImgIds()
image_ids.sort()
roidb = copy.deepcopy(self.COCO.loadImgs(image_ids))
for entry in roidb:
self._prep_roidb_entry(entry)
if self.is_train:
# Include ground-truth object annotations
start_time = time.time()
for entry in roidb:
self._add_gt_annotations(entry)
end_time = time.time()
logger.debug('_add_gt_annotations took {:.3f}s'.format(end_time -
start_time))
logger.info('Appending horizontally-flipped training examples...')
self._extend_with_flipped_entries(roidb)
logger.info('Loaded dataset: {:s}'.format(self.name))
logger.info('{:d} roidb entries'.format(len(roidb)))
return roidb
def _prep_roidb_entry(self, entry):
"""Adds empty metadata fields to an roidb entry."""
# Make file_name an abs path
im_path = os.path.join(self.image_directory, entry['file_name'])
#assert os.path.exists(im_path), 'Image \'{}\' not found'.format(im_path)
entry['image'] = im_path
entry['flipped'] = False
# Empty placeholders
entry['gt_boxes'] = np.empty((0, 4), dtype=np.float32)
entry['gt_classes'] = np.empty((0), dtype=np.int32)
entry['gt_id'] = np.empty((0), dtype=np.int32)
entry['is_crowd'] = np.empty((0), dtype=np.bool)
# Remove unwanted fields that come from the json file (if they exist)
for k in ['date_captured', 'url', 'license', 'file_name']:
if k in entry:
del entry[k]
def _add_gt_annotations(self, entry):
"""Add ground truth annotation metadata to an roidb entry."""
count = 0
#for k in self.category_to_id_map:
# imgs = self.COCO.getImgIds(catIds=(self.category_to_id_map[k]))
# count += len(imgs)
ann_ids = self.COCO.getAnnIds(imgIds=entry['id'], iscrowd=None)
objs = self.COCO.loadAnns(ann_ids)
# Sanitize bboxes -- some are invalid
valid_objs = []
width = entry['width']
height = entry['height']
for obj in objs:
if obj['area'] < -1: #cfg.TRAIN.GT_MIN_AREA:
continue
if 'ignore' in obj and obj['ignore'] == 1:
continue
# Convert form (x1, y1, w, h) to (x1, y1, x2, y2)
x1, y1, x2, y2 = box_utils.xywh_to_xyxy(obj['bbox'])
x1, y1, x2, y2 = box_utils.clip_xyxy_to_image(x1, y1, x2, y2,
height, width)
# Require non-zero seg area and more than 1x1 box size
if obj['area'] > 0 and x2 > x1 and y2 > y1:
obj['clean_bbox'] = [x1, y1, x2, y2]
valid_objs.append(obj)
num_valid_objs = len(valid_objs)
gt_boxes = np.zeros((num_valid_objs, 4), dtype=entry['gt_boxes'].dtype)
gt_id = np.zeros((num_valid_objs), dtype=np.int32)
gt_classes = np.zeros((num_valid_objs), dtype=entry['gt_classes'].dtype)
is_crowd = np.zeros((num_valid_objs), dtype=entry['is_crowd'].dtype)
for ix, obj in enumerate(valid_objs):
cls = self.json_category_id_to_contiguous_id[obj['category_id']]
gt_boxes[ix, :] = obj['clean_bbox']
gt_classes[ix] = cls
gt_id[ix] = np.int32(obj['id'])
is_crowd[ix] = obj['iscrowd']
entry['gt_boxes'] = np.append(entry['gt_boxes'], gt_boxes, axis=0)
entry['gt_classes'] = np.append(entry['gt_classes'], gt_classes)
entry['gt_id'] = np.append(entry['gt_id'], gt_id)
entry['is_crowd'] = np.append(entry['is_crowd'], is_crowd)
def _extend_with_flipped_entries(self, roidb):
"""Flip each entry in the given roidb and return a new roidb that is the
concatenation of the original roidb and the flipped entries.
"Flipping" an entry means that that image and associated metadata (e.g.,
ground truth boxes and object proposals) are horizontally flipped.
"""
flipped_roidb = []
for entry in roidb:
width = entry['width']
gt_boxes = entry['gt_boxes'].copy()
oldx1 = gt_boxes[:, 0].copy()
oldx2 = gt_boxes[:, 2].copy()
gt_boxes[:, 0] = width - oldx2 - 1
gt_boxes[:, 2] = width - oldx1 - 1
assert (gt_boxes[:, 2] >= gt_boxes[:, 0]).all()
flipped_entry = {}
dont_copy = ('gt_boxes', 'flipped')
for k, v in entry.items():
if k not in dont_copy:
flipped_entry[k] = v
flipped_entry['gt_boxes'] = gt_boxes
flipped_entry['flipped'] = True
flipped_roidb.append(flipped_entry)
roidb.extend(flipped_roidb)
import os
import time
import numpy as np
import argparse
import functools
import shutil
import cPickle
from utility import add_arguments, print_arguments
import paddle
import paddle.fluid as fluid
import reader
from fasterrcnn_model import FasterRcnn, RPNloss
from learning_rate import exponential_with_warmup_decay
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
# ENV
add_arg('parallel', bool, True, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, 'imagenet_resnet50_fusebn', "The init model path.")
add_arg('dataset', str, 'coco2017', "coco2014, coco2017, and pascalvoc.")
add_arg('data_dir', str, 'data/COCO17', "data directory")
# SOLVER
add_arg('learning_rate', float, 0.01, "Learning rate.")
add_arg('num_passes', int, 20, "Epoch number.")
# RPN
add_arg('anchor_sizes', int, [32,64,128,256,512], "The size of anchors.")
add_arg('aspect_ratios', float, [0.5,1.0,2.0], "The ratio of anchors.")
add_arg('rpn_stride', float, 16., "Stride of the feature map that RPN is attached.")
# FAST RCNN
# TRAIN TEST
add_arg('batch_size', int, 1, "Minibatch size.")
add_arg('max_size', int, 1333, "The max resized image size.")
add_arg('scales', int, [800], "The resized image height.")
add_arg('batch_size_per_im',int, 512, "fast rcnn head batch size")
add_arg('mean_value', float, [102.9801, 115.9465, 122.7717], "pixel mean")
add_arg('debug', bool, False, "Debug mode")
#yapf: enable
def train(args):
num_passes = args.num_passes
batch_size = args.batch_size
learning_rate = args.learning_rate
image_shape = [3, args.max_size, args.max_size]
if args.debug:
fluid.default_startup_program().random_seed = 1000
fluid.default_main_program().random_seed = 1000
import random
random.seed(0)
np.random.seed(0)
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
gt_box = fluid.layers.data(
name='gt_box', shape=[4], dtype='float32', lod_level=1)
gt_label = fluid.layers.data(
name='gt_label', shape=[1], dtype='int32', lod_level=1)
is_crowd = fluid.layers.data(
name='is_crowd', shape = [-1], dtype='int32', lod_level=1, append_batch_size=False)
im_info = fluid.layers.data(
name='im_info', shape=[3], dtype='float32')
rpn_cls_score, rpn_bbox_pred, anchor, var, cls_score, bbox_pred,\
bbox_targets, bbox_inside_weights, bbox_outside_weights, rois, \
labels_int32 = FasterRcnn(
input=image,
depth=50,
anchor_sizes=[32,64,128,256,512],
variance=[1.,1.,1.,1.],
aspect_ratios=[0.5,1.0,2.0],
gt_box=gt_box,
is_crowd=is_crowd,
gt_label=gt_label,
im_info=im_info,
class_nums=args.class_nums,
use_random=False if args.debug else True
)
cls_loss, reg_loss = RPNloss(rpn_cls_score, rpn_bbox_pred, anchor, var, \
gt_box, is_crowd, im_info, use_random=False if args.debug else True)
cls_loss.persistable=True
reg_loss.persistable=True
rpn_loss = cls_loss + reg_loss
rpn_loss.persistable=True
labels_int64 = fluid.layers.cast(x=labels_int32, dtype='int64')
labels_int64.stop_gradient = True
#loss_cls = fluid.layers.softmax_with_cross_entropy(
# logits=cls_score,
# label=labels_int64
# )
softmax = fluid.layers.softmax(cls_score, use_cudnn=False)
loss_cls = fluid.layers.cross_entropy(softmax, labels_int64)
loss_cls = fluid.layers.reduce_mean(loss_cls)
loss_cls.persistable=True
loss_bbox = fluid.layers.smooth_l1(x=bbox_pred,
y=bbox_targets,
inside_weight=bbox_inside_weights,
outside_weight=bbox_outside_weights,
sigma=1.0)
loss_bbox = fluid.layers.reduce_mean(loss_bbox)
loss_bbox.persistable=True
loss_cls.persistable=True
loss_bbox.persistable=True
detection_loss = loss_cls + loss_bbox
detection_loss.persistable=True
loss = rpn_loss + detection_loss
loss.persistable=True
boundaries = [120000, 160000]
values = [learning_rate, learning_rate*0.1, learning_rate*0.01]
optimizer = fluid.optimizer.Momentum(
learning_rate=exponential_with_warmup_decay(learning_rate=learning_rate,
boundaries=boundaries,
values=values,
warmup_iter=500,
warmup_factor=1.0/3.0),
regularization=fluid.regularizer.L2Decay(0.0001),
momentum=0.9)
optimizer.minimize(loss)
fluid.memory_optimize(fluid.default_main_program())
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if args.pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(args.pretrained_model, var.name))
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
if args.parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=bool(args.use_gpu), loss_name=loss.name)
train_reader = reader.train(args)
def save_model(postfix):
model_path = os.path.join(args.model_save_dir, postfix)
if os.path.isdir(model_path):
shutil.rmtree(model_path)
fluid.io.save_persistables(exe, model_path)
fetch_list = [loss, cls_loss, reg_loss, loss_cls, loss_bbox]
def tensor(data, place, lod=None):
t = fluid.core.LoDTensor()
t.set(data, place)
if lod:
t.set_lod(lod)
return t
total_time = 0.0
for epoc_id in range(num_passes):
start_time = time.time()
prev_start_time = start_time
every_pass_loss = []
iter = 0
pass_duration = 0.0
for batch_id, data in enumerate(train_reader()):
prev_start_time = start_time
start_time = time.time()
image, gt_box, gt_label, is_crowd, im_info, lod = data
image_t = tensor(image, place)
gt_box_t = tensor(gt_box, place, [lod])
gt_label_t = tensor(gt_label, place, [lod])
is_crowd_t = tensor(is_crowd, place, [lod])
im_info_t = tensor(im_info, place)
feeding = {}
feeding['image'] = image_t
feeding['gt_box'] = gt_box_t
feeding['gt_label'] = gt_label_t
feeding['is_crowd'] = is_crowd_t
feeding['im_info'] = im_info_t
if args.parallel:
losses = train_exe.run(fetch_list=[v.name for v in fetch_list],
feed=feeding)
else:
losses = exe.run(fluid.default_main_program(),
feed=feeding,
fetch_list=fetch_list)
loss_v = np.mean(np.array(losses[0]))
every_pass_loss.append(loss_v)
lr = np.array(fluid.global_scope().find_var('learning_rate').get_tensor())
if batch_id % 1 == 0:
print("Epoc {:d}, batch {:d}, lr {:.6f}, loss {:.6f}, time {:.5f}".format(
epoc_id, batch_id, lr[0], losses[0][0], start_time - prev_start_time))
#print('cls_loss ', losses[1][0], ' reg_loss ', losses[2][0], ' loss_cls ', losses[3][0], ' loss_bbox ', losses[4][0])
if epoc_id % 10 == 0 or epoc_id == num_passes - 1:
save_model(str(epoc_id))
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_args = reader.Settings(args)
train(data_args)
"""Contains common utility functions."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import numpy as np
import six
from paddle.fluid import core
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(six.iteritems(vars(args))):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册