未验证 提交 d383fd09 编写于 作者: W wangguanzhong 提交者: GitHub

add yolov3_darknet (#1299)

上级 646996f4
......@@ -130,7 +130,7 @@ LearningRate:
gamma: 0.1
milestones: [120000, 160000]
- !LinearWarmup
start_factor: 0.3333
start_factor: 0.3333333
steps: 500
OptimizerBuilder:
......
......@@ -3,13 +3,13 @@ use_gpu: true
max_iters: 500000
log_smooth_window: 20
save_dir: output
snapshot_iter: 10000
snapshot_iter: 50000
metric: COCO
pretrain_weights: https://paddlemodels.bj.bcebos.com/yolo/darknet53.pdparams
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
weights: output/yolov3_darknet/model_final
num_classes: 80
use_fine_grained_loss: false
open_debug: False
load_static_weights: True
YOLOv3:
anchor: AnchorYOLO
......@@ -18,11 +18,15 @@ YOLOv3:
DarkNet:
depth: 53
return_idx: [2, 3, 4]
YOLOv3Head:
yolo_feat:
name: YOLOFeat
feat_in_list: [1024, 768, 384]
ignore_thresh: 0.7
downsample: 32
label_smooth: true
anchor_per_position: 3
AnchorYOLO:
......@@ -30,11 +34,6 @@ AnchorYOLO:
name: AnchorGeneratorYOLO
anchors: [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchor_target_generator:
name: AnchorTargetGeneratorYOLO
ignore_thresh: 0.7
downsample_ratio: 32
label_smooth: true
anchor_post_process:
name: BBoxPostProcessYOLO
# decode -> clip
......
......@@ -27,8 +27,8 @@ TrainReader:
- !BboxXYXY2XYWH {}
batch_transforms:
- !RandomShape
sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_inter: True
sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_inter: True
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
......@@ -50,8 +50,8 @@ TrainReader:
shuffle: true
mixup_epoch: 250
drop_last: true
worker_num: 8
bufsize: 16
worker_num: 4
bufsize: 4
use_process: true
......@@ -81,7 +81,7 @@ EvalReader:
- !Permute
to_bgr: false
channel_first: True
batch_size: 8
batch_size: 1
drop_empty: false
worker_num: 8
bufsize: 16
......
......@@ -201,7 +201,7 @@ class Reader(object):
use_fine_grained_loss=False,
num_classes=80,
bufsize=-1,
memsize='3G',
memsize='500M',
inputs_def=None,
devices_num=1):
self._dataset = dataset
......
......@@ -17,38 +17,34 @@ class YOLOv3(BaseArch):
'yolo_head',
]
def __init__(self, anchor, backbone, yolo_head, *args, **kwargs):
super(YOLOv3, self).__init__(*args, **kwargs)
def __init__(self, anchor, backbone, yolo_head):
super(YOLOv3, self).__init__()
self.anchor = anchor
self.backbone = backbone
self.yolo_head = yolo_head
def model_arch(self, ):
# Backbone
bb_out = self.backbone(self.gbd)
self.gbd.update(bb_out)
body_feats = self.backbone(self.inputs)
# YOLO Head
yolo_head_out = self.yolo_head(self.gbd)
self.gbd.update(yolo_head_out)
self.yolo_head_out = self.yolo_head(body_feats)
# Anchor
anchor_out = self.anchor(self.gbd)
self.gbd.update(anchor_out)
if self.gbd['mode'] == 'infer':
bbox_out = self.anchor.post_process(self.gbd)
self.gbd.update(bbox_out)
self.anchors, self.anchor_masks, self.mask_anchors = self.anchor()
def loss(self, ):
yolo_loss = self.yolo_head.loss(self.gbd)
out = {'loss': yolo_loss}
return out
yolo_loss = self.yolo_head.loss(self.inputs, self.yolo_head_out,
self.anchors, self.anchor_masks,
self.mask_anchors)
return yolo_loss
def infer(self, ):
bbox, bbox_num = self.anchor.post_process(
self.inputs['im_size'], self.yolo_head_out, self.mask_anchors)
outs = {
"bbox": self.gbd['predicted_bbox'].numpy(),
"bbox_nums": self.gbd['predicted_bbox_nums'],
'im_id': self.gbd['im_id'].numpy()
"bbox": bbox.numpy(),
"bbox_num": bbox_num,
'im_id': self.inputs['im_id'].numpy()
}
return outs
......@@ -16,7 +16,8 @@ class ConvBNLayer(Layer):
stride=1,
groups=1,
padding=0,
act="leaky"):
act="leaky",
name=None):
super(ConvBNLayer, self).__init__()
self.conv = Conv2D(
......@@ -26,18 +27,18 @@ class ConvBNLayer(Layer):
stride=stride,
padding=padding,
groups=groups,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02)),
param_attr=ParamAttr(name=name + '.conv.weights'),
bias_attr=False,
act=None)
bn_name = name + '.bn'
self.batch_norm = BatchNorm(
num_channels=ch_out,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02),
regularizer=L2Decay(0.)),
name=bn_name + '.scale', regularizer=L2Decay(0.)),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.)))
name=bn_name + '.offset', regularizer=L2Decay(0.)),
moving_mean_name=bn_name + '.mean',
moving_variance_name=bn_name + '.var')
self.act = act
......@@ -50,7 +51,13 @@ class ConvBNLayer(Layer):
class DownSample(Layer):
def __init__(self, ch_in, ch_out, filter_size=3, stride=2, padding=1):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=2,
padding=1,
name=None):
super(DownSample, self).__init__()
......@@ -59,7 +66,8 @@ class DownSample(Layer):
ch_out=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding)
padding=padding,
name=name)
self.ch_out = ch_out
def forward(self, inputs):
......@@ -68,13 +76,23 @@ class DownSample(Layer):
class BasicBlock(Layer):
def __init__(self, ch_in, ch_out):
def __init__(self, ch_in, ch_out, name=None):
super(BasicBlock, self).__init__()
self.conv1 = ConvBNLayer(
ch_in=ch_in, ch_out=ch_out, filter_size=1, stride=1, padding=0)
ch_in=ch_in,
ch_out=ch_out,
filter_size=1,
stride=1,
padding=0,
name=name + '.0')
self.conv2 = ConvBNLayer(
ch_in=ch_out, ch_out=ch_out * 2, filter_size=3, stride=1, padding=1)
ch_in=ch_out,
ch_out=ch_out * 2,
filter_size=3,
stride=1,
padding=1,
name=name + '.1')
def forward(self, inputs):
conv1 = self.conv1(inputs)
......@@ -84,14 +102,16 @@ class BasicBlock(Layer):
class Blocks(Layer):
def __init__(self, ch_in, ch_out, count):
def __init__(self, ch_in, ch_out, count, name=None):
super(Blocks, self).__init__()
self.basicblock0 = BasicBlock(ch_in, ch_out)
self.basicblock0 = BasicBlock(ch_in, ch_out, name=name + '.0')
self.res_out_list = []
for i in range(1, count):
res_out = self.add_sublayer("basic_block_%d" % (i),
BasicBlock(ch_out * 2, ch_out))
block_name = '{}.{}'.format(name, i)
res_out = self.add_sublayer(
block_name, BasicBlock(
ch_out * 2, ch_out, name=block_name))
self.res_out_list.append(res_out)
self.ch_out = ch_out
......@@ -108,31 +128,46 @@ DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}
@register
@serializable
class DarkNet(Layer):
def __init__(self, depth=53, mode='train'):
def __init__(self,
depth=53,
freeze_at=-1,
return_idx=[2, 3, 4],
num_stages=5):
super(DarkNet, self).__init__()
self.depth = depth
self.mode = mode
self.stages = DarkNet_cfg[self.depth][0:5]
self.freeze_at = freeze_at
self.return_idx = return_idx
self.num_stages = num_stages
self.stages = DarkNet_cfg[self.depth][0:num_stages]
self.conv0 = ConvBNLayer(
ch_in=3, ch_out=32, filter_size=3, stride=1, padding=1)
ch_in=3,
ch_out=32,
filter_size=3,
stride=1,
padding=1,
name='yolo_input')
self.downsample0 = DownSample(ch_in=32, ch_out=32 * 2)
self.downsample0 = DownSample(
ch_in=32, ch_out=32 * 2, name='yolo_input.downsample')
self.darknet53_conv_block_list = []
self.darknet_conv_block_list = []
self.downsample_list = []
ch_in = [64, 128, 256, 512, 1024]
for i, stage in enumerate(self.stages):
conv_block = self.add_sublayer("stage_%d" % (i),
Blocks(
int(ch_in[i]), 32 * (2**i),
stage))
self.darknet53_conv_block_list.append(conv_block)
for i in range(len(self.stages) - 1):
name = 'stage.{}'.format(i)
conv_block = self.add_sublayer(
name, Blocks(
int(ch_in[i]), 32 * (2**i), stage, name=name))
self.darknet_conv_block_list.append(conv_block)
for i in range(num_stages - 1):
down_name = 'stage.{}.downsample'.format(i)
downsample = self.add_sublayer(
"stage_%d_downsample" % i,
down_name,
DownSample(
ch_in=32 * (2**(i + 1)), ch_out=32 * (2**(i + 2))))
ch_in=32 * (2**(i + 1)),
ch_out=32 * (2**(i + 2)),
name=down_name))
self.downsample_list.append(downsample)
def forward(self, inputs):
......@@ -141,10 +176,12 @@ class DarkNet(Layer):
out = self.conv0(x)
out = self.downsample0(out)
blocks = []
for i, conv_block_i in enumerate(self.darknet53_conv_block_list):
for i, conv_block_i in enumerate(self.darknet_conv_block_list):
out = conv_block_i(out)
blocks.append(out)
if i < len(self.stages) - 1:
if i == self.freeze_at:
out.stop_gradient = True
if i in self.return_idx:
blocks.append(out)
if i < self.num_stages - 1:
out = self.downsample_list[i](out)
outs = {'darknet_outs': blocks[-1:-4:-1]}
return outs
return blocks
......@@ -79,27 +79,25 @@ class BBoxPostProcessYOLO(object):
self.decode = decode
self.clip = clip
def __call__(self, inputs):
def __call__(self, im_size, yolo_head_out, mask_anchors):
# TODO: split yolo_box into 2 steps
# decode
# clip
boxes_list = []
scores_list = []
for i, out in enumerate(inputs['yolo_outs']):
boxes, scores = self.yolo_box(out, inputs['im_size'],
inputs['mask_anchors'][i], i,
"yolo_box_" + str(i))
for i, head_out in enumerate(yolo_head_out):
boxes, scores = self.yolo_box(head_out, im_size, mask_anchors[i],
self.num_classes, i)
boxes_list.append(boxes)
scores_list.append(fluid.layers.transpose(scores, perm=[0, 2, 1]))
yolo_boxes = fluid.layers.concat(boxes_list, axis=1)
yolo_scores = fluid.layers.concat(scores_list, axis=2)
nmsed_bbox = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
bbox = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
# TODO: parse the lod of nmsed_bbox
# default batch size is 1
bbox_nums = np.array([0, int(nmsed_bbox.shape[0])], dtype=np.int32)
outs = {"predicted_bbox_nums": bbox_nums, "predicted_bbox": nmsed_bbox}
return outs
bbox_num = np.array([int(bbox.shape[0])], dtype=np.int32)
return bbox, bbox_num
@register
......@@ -168,32 +166,18 @@ class AnchorRPN(object):
@register
class AnchorYOLO(object):
__inject__ = [
'anchor_generator', 'anchor_target_generator', 'anchor_post_process'
]
__inject__ = ['anchor_generator', 'anchor_post_process']
def __init__(self, anchor_generator, anchor_target_generator,
anchor_post_process):
def __init__(self, anchor_generator, anchor_post_process):
super(AnchorYOLO, self).__init__()
self.anchor_generator = anchor_generator
self.anchor_target_generator = anchor_target_generator
self.anchor_post_process = anchor_post_process
def __call__(self, inputs):
outs = self.generate_anchors(inputs)
return outs
def generate_anchors(self, inputs):
outs = self.anchor_generator(inputs['yolo_outs'])
outs['anchor_module'] = self
return outs
def generate_anchors_target(self, inputs):
outs = self.anchor_target_generator()
return outs
def __call__(self):
return self.anchor_generator()
def post_process(self, inputs):
return self.anchor_post_process(inputs)
def post_process(self, im_size, yolo_head_out, mask_anchors):
return self.anchor_post_process(im_size, yolo_head_out, mask_anchors)
@register
......
......@@ -37,7 +37,7 @@ class MaskFeat(Layer):
mask_conv.add_sublayer(
conv_name,
Conv2D(
num_channels=feat_in if j == 1 else feat_out,
num_channels=feat_in if j == 0 else feat_out,
num_filters=feat_out,
filter_size=3,
act='relu',
......
import paddle.fluid as fluid
import paddle
from paddle.fluid.dygraph import Layer
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm
from paddle.fluid.dygraph import Sequential
from ppdet.core.workspace import register
from ..backbone.darknet import ConvBNLayer
class YoloDetBlock(Layer):
def __init__(self, ch_in, channel):
def __init__(self, ch_in, channel, name):
super(YoloDetBlock, self).__init__()
self.ch_in = ch_in
self.channel = channel
assert channel % 2 == 0, \
"channel {} cannot be divided by 2".format(channel)
self.conv0 = ConvBNLayer(
ch_in=ch_in, ch_out=channel, filter_size=1, stride=1, padding=0)
self.conv1 = ConvBNLayer(
ch_in=channel,
ch_out=channel * 2,
filter_size=3,
stride=1,
padding=1)
self.conv2 = ConvBNLayer(
ch_in=channel * 2,
ch_out=channel,
filter_size=1,
stride=1,
padding=0)
self.conv3 = ConvBNLayer(
ch_in=channel,
ch_out=channel * 2,
filter_size=3,
stride=1,
padding=1)
self.route = ConvBNLayer(
ch_in=channel * 2,
ch_out=channel,
filter_size=1,
stride=1,
padding=0)
conv_def = [
['conv0', ch_in, channel, 1, '.0.0'],
['conv1', channel, channel * 2, 3, '.0.1'],
['conv2', channel * 2, channel, 1, '.1.0'],
['conv3', channel, channel * 2, 3, '.1.1'],
['route', channel * 2, channel, 1, '.2'],
#['tip', channel, channel * 2, 3],
]
self.conv_module = Sequential()
for idx, (conv_name, ch_in, ch_out, filter_size,
post_name) in enumerate(conv_def):
self.conv_module.add_sublayer(
conv_name,
ConvBNLayer(
ch_in=ch_in,
ch_out=ch_out,
filter_size=filter_size,
padding=(filter_size - 1) // 2,
name=name + post_name))
self.tip = ConvBNLayer(
ch_in=channel,
ch_out=channel * 2,
filter_size=3,
stride=1,
padding=1)
padding=1,
name=name + '.tip')
def forward(self, inputs):
out = self.conv0(inputs)
out = self.conv1(out)
out = self.conv2(out)
out = self.conv3(out)
route = self.route(out)
route = self.conv_module(inputs)
tip = self.tip(route)
return route, tip
class Upsample(Layer):
def __init__(self, scale=2):
super(Upsample, self).__init__()
self.scale = scale
def forward(self, inputs):
# get dynamic upsample output shape
shape_nchw = fluid.layers.shape(inputs)
shape_hw = fluid.layers.slice(
shape_nchw, axes=[0], starts=[2], ends=[4])
shape_hw.stop_gradient = True
in_shape = fluid.layers.cast(shape_hw, dtype='int32')
out_shape = in_shape * self.scale
out_shape.stop_gradient = True
# reisze by actual_shape
out = fluid.layers.resize_nearest(
input=inputs, scale=self.scale, actual_shape=out_shape)
return out
@register
class YOLOFeat(Layer):
def __init__(self, feat_in_list=[1024, 768, 384]):
__shared__ = ['num_levels']
def __init__(self, feat_in_list=[1024, 768, 384], num_levels=3):
super(YOLOFeat, self).__init__()
self.feat_in_list = feat_in_list
self.yolo_blocks = []
self.route_blocks = []
for i in range(3):
self.num_levels = num_levels
for i in range(self.num_levels):
name = 'yolo_block.{}'.format(i)
yolo_block = self.add_sublayer(
"yolo_det_block_%d" % (i),
name,
YoloDetBlock(
feat_in_list[i], channel=512 // (2**i)))
feat_in_list[i], channel=512 // (2**i), name=name))
self.yolo_blocks.append(yolo_block)
if i < 2:
if i < self.num_levels - 1:
name = 'yolo_transition.{}'.format(i)
route = self.add_sublayer(
"route_%d" % i,
name,
ConvBNLayer(
ch_in=512 // (2**i),
ch_out=256 // (2**i),
filter_size=1,
stride=1,
padding=0))
padding=0,
name=name))
self.route_blocks.append(route)
self.upsample = Upsample()
def forward(self, inputs):
def forward(self, body_feats):
assert len(body_feats) == self.num_levels
body_feats = body_feats[::-1]
yolo_feats = []
for i, block in enumerate(inputs['darknet_outs']):
for i, block in enumerate(body_feats):
if i > 0:
block = fluid.layers.concat(input=[route, block], axis=1)
route, tip = self.yolo_blocks[i](block)
yolo_feats.append(tip)
if i < 2:
if i < self.num_levels - 1:
route = self.route_blocks[i](route)
route = self.upsample(route)
route = fluid.layers.resize_nearest(route, scale=2.)
outs = {'yolo_feat': yolo_feats}
return outs
return yolo_feats
@register
class YOLOv3Head(Layer):
__shared__ = ['num_classes']
__shared__ = ['num_classes', 'num_levels', 'use_fine_grained_loss']
__inject__ = ['yolo_feat']
def __init__(self, yolo_feat, num_classes=80, anchor_per_position=3):
def __init__(self,
yolo_feat,
num_classes=80,
anchor_per_position=3,
num_levels=3,
use_fine_grained_loss=False,
ignore_thresh=0.7,
downsample=32,
label_smooth=True):
super(YOLOv3Head, self).__init__()
self.num_classes = num_classes
self.anchor_per_position = anchor_per_position
self.yolo_feat = yolo_feat
self.yolo_outs = []
for i in range(3):
self.num_levels = num_levels
self.use_fine_grained_loss = use_fine_grained_loss
self.ignore_thresh = ignore_thresh
self.downsample = downsample
self.label_smooth = label_smooth
self.yolo_out_list = []
for i in range(num_levels):
# TODO: optim here
#num_filters = len(cfg.anchor_masks[i]) * (self.num_classes + 5)
num_filters = self.anchor_per_position * (self.num_classes + 5)
name = 'yolo_output.{}'.format(i)
yolo_out = self.add_sublayer(
"yolo_out_%d" % (i),
name,
Conv2D(
num_channels=1024 // (2**i),
num_filters=num_filters,
......@@ -152,44 +137,38 @@ class YOLOv3Head(Layer):
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02)),
param_attr=ParamAttr(name=name + '.conv.weights'),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.))))
self.yolo_outs.append(yolo_out)
name=name + '.conv.bias', regularizer=L2Decay(0.))))
self.yolo_out_list.append(yolo_out)
def forward(self, body_feats):
assert len(body_feats) == self.num_levels
yolo_feats = self.yolo_feat(body_feats)
yolo_head_out = []
for i, feat in enumerate(yolo_feats):
yolo_out = self.yolo_out_list[i](feat)
yolo_head_out.append(yolo_out)
return yolo_head_out
def loss(self, inputs, head_out, anchors, anchor_masks, mask_anchors):
if self.use_fine_grained_loss:
raise NotImplementedError
def forward(self, inputs):
outs = self.yolo_feat(inputs)
x = outs['yolo_feat']
yolo_out_list = []
for i, yolo_f in enumerate(x):
yolo_out = self.yolo_outs[i](yolo_f)
yolo_out_list.append(yolo_out)
outs.update({"yolo_outs": yolo_out_list})
return outs
def loss(self, inputs):
if callable(inputs['anchor_module']):
yolo_targets = inputs['anchor_module'].generate_anchors_target(
inputs)
yolo_losses = []
for i, out in enumerate(inputs['yolo_outs']):
# TODO: split yolov3_loss into small ops
# 1. compute target 2. loss
for i, out in enumerate(head_out):
loss = fluid.layers.yolov3_loss(
x=out,
gt_box=inputs['gt_bbox'],
gt_label=inputs['gt_class'],
gt_score=inputs['gt_score'],
anchors=inputs['anchors'],
anchor_mask=inputs['anchor_masks'][i],
anchors=anchors,
anchor_mask=anchor_masks[i],
class_num=self.num_classes,
ignore_thresh=yolo_targets['ignore_thresh'],
downsample_ratio=yolo_targets['downsample_ratio'] // 2**i,
use_label_smooth=yolo_targets['label_smooth'],
ignore_thresh=self.ignore_thresh,
downsample_ratio=self.downsample // 2**i,
use_label_smooth=self.label_smooth,
name='yolo_loss_' + str(i))
loss = fluid.layers.reduce_mean(loss)
yolo_losses.append(loss)
yolo_loss = sum(yolo_losses)
return yolo_loss
return {'loss': sum(yolo_losses)}
......@@ -99,20 +99,16 @@ class AnchorGeneratorYOLO(object):
self.anchors = anchors
self.anchor_masks = anchor_masks
def __call__(self, yolo_outs):
def __call__(self):
anchor_num = len(self.anchors)
mask_anchors = []
for i, _ in enumerate(yolo_outs):
for i in range(len(self.anchor_masks)):
mask_anchor = []
for m in self.anchor_masks[i]:
mask_anchor.append(self.anchors[2 * m])
mask_anchor.append(self.anchors[2 * m + 1])
assert m < anchor_num, "anchor mask index overflow"
mask_anchor.extend(self.anchors[2 * m:2 * m + 2])
mask_anchors.append(mask_anchor)
outs = {
"anchors": self.anchors,
"anchor_masks": self.anchor_masks,
"mask_anchors": mask_anchors
}
return outs
return self.anchors, self.anchor_masks, mask_anchors
@register
......@@ -305,6 +301,7 @@ class RoIExtractor(object):
self.canconical_level,
self.canonical_size,
rois_num=rois_num)
rois_feat_list = []
for lvl in range(self.start_level, self.end_level + 1):
roi_feat = fluid.layers.roi_align(
......@@ -381,24 +378,19 @@ class MultiClassNMS(object):
@register
@serializable
class YOLOBox(object):
__shared__ = ['num_classes']
def __init__(
self,
num_classes=80,
conf_thresh=0.005,
downsample_ratio=32,
clip_bbox=True, ):
self.num_classes = num_classes
self.conf_thresh = conf_thresh
self.downsample_ratio = downsample_ratio
self.clip_bbox = clip_bbox
def __call__(self, x, img_size, anchors, stage=0, name=None):
outs = fluid.layers.yolo_box(x, img_size, anchors, self.num_classes,
def __call__(self, x, img_size, anchors, num_classes, stage=0):
outs = fluid.layers.yolo_box(x, img_size, anchors, num_classes,
self.conf_thresh, self.downsample_ratio //
2**stage, self.clip_bbox, name)
2**stage, self.clip_bbox)
return outs
......
......@@ -82,7 +82,7 @@ class LinearWarmup(object):
def __call__(self, base_lr):
boundary = []
value = []
for i in range(self.steps):
for i in range(self.steps + 1):
alpha = i / self.steps
factor = self.start_factor * (1 - alpha) + alpha
lr = base_lr * factor
......
......@@ -190,7 +190,7 @@ def nms_with_decode(bboxes,
rois_n = bboxes_v[start:end, :] # box
rois_n = rois_n / im_info[i][2] # scale
rois_n = delta2bbox(bbox_deltas_n, rois_n, variance_v)
rois_n = clip_bbox(rois_n, im_info[i][:2] / im_info[i][2])
rois_n = clip_bbox(rois_n, np.round(im_info[i][:2] / im_info[i][2]))
cls_boxes = [[] for _ in range(class_nums)]
scores_n = bbox_probs_v[start:end, :]
for j in range(1, class_nums):
......
......@@ -136,7 +136,8 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map):
k = 0
for i in range(len(bbox_nums)):
image_id = int(image_id[i][0])
if bboxes.shape == (1, 1):
continue
det_nums = bbox_nums[i]
for j in range(det_nums):
dt = bboxes[k]
......
......@@ -3,10 +3,12 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import errno
import os
import time
import re
import numpy as np
import paddle
import paddle.fluid as fluid
from .download import get_weights_path
......@@ -88,9 +90,10 @@ def load_dygraph_ckpt(model,
return model
def save_dygraph_ckpt(model, optimizer, save_dir):
def save_dygraph_ckpt(model, optimizer, save_dir, save_name):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
fluid.dygraph.save_dygraph(model.state_dict(), save_dir)
fluid.dygraph.save_dygraph(optimizer.state_dict(), save_dir)
save_path = os.path.join(save_dir, save_name)
fluid.dygraph.save_dygraph(model.state_dict(), save_path)
fluid.dygraph.save_dygraph(optimizer.state_dict(), save_path)
print("Save checkpoint:", save_dir)
......@@ -20,6 +20,10 @@ from ppdet.utils.cli import ArgsParser
from ppdet.utils.eval_utils import coco_eval_results
from ppdet.data.reader import create_reader
from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def parse_args():
......@@ -58,22 +62,26 @@ def run(FLAGS, cfg):
# Run Eval
outs_res = []
start_time = time.time()
sample_num = 0
for iter_id, data in enumerate(eval_reader()):
start_time = time.time()
# forward
model.eval()
outs = model(data, cfg['EvalReader']['inputs_def']['fields'], 'infer')
outs_res.append(outs)
# log
cost_time = time.time() - start_time
print("Eval iter: {}, time: {}".format(iter_id, cost_time))
sample_num += len(data)
if iter_id % 100 == 0:
logger.info("Eval iter: {}".format(iter_id))
cost_time = time.time() - start_time
logger.info('Total sample number: {}, averge FPS: {}'.format(
sample_num, sample_num / cost_time))
# Metric
coco_eval_results(
outs_res,
include_mask=True if 'MaskHead' in cfg else False,
include_mask=True if getattr(cfg, 'MaskHead', None) else False,
dataset=cfg['EvalReader']['dataset'])
......
......@@ -113,7 +113,7 @@ def run(FLAGS, cfg):
optimizer,
cfg.pretrain_weights,
ckpt_type=FLAGS.ckpt_type,
load_static_weights=cfg.load_static_weights)
load_static_weights=cfg.get('load_static_weights', False))
# Parallel Model
if ParallelEnv().nranks > 1:
......@@ -177,8 +177,8 @@ def run(FLAGS, cfg):
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_name = str(
iter_id) if iter_id != cfg.max_iters - 1 else "model_final"
save_dir = os.path.join(cfg.save_dir, cfg_name, save_name)
save_dygraph_ckpt(model, optimizer, save_dir)
save_dir = os.path.join(cfg.save_dir, cfg_name)
save_dygraph_ckpt(model, optimizer, save_dir, save_name)
def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册