提交 8152934d 编写于 作者: D dengkaipeng

refine code.

上级 7c7b0c4d
...@@ -64,6 +64,16 @@ _C.pixel_means = [0.485, 0.456, 0.406] ...@@ -64,6 +64,16 @@ _C.pixel_means = [0.485, 0.456, 0.406]
# pixel std values # pixel std values
_C.pixel_stds = [0.229, 0.224, 0.225] _C.pixel_stds = [0.229, 0.224, 0.225]
# anchors box weight and height
_C.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
# anchor mask of each yolo layer
_C.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
# IoU threshold to ignore objectness loss of pred box
_C.ignore_thresh = .7
# #
# SOLVER options # SOLVER options
# #
......
...@@ -86,14 +86,11 @@ def add_DarkNet53_conv_body(body_input, is_test=True): ...@@ -86,14 +86,11 @@ def add_DarkNet53_conv_body(body_input, is_test=True):
conv1 = conv_bn_layer( conv1 = conv_bn_layer(
body_input, ch_out=32, filter_size=3, stride=1, padding=1, is_test=is_test, name="yolo_input") body_input, ch_out=32, filter_size=3, stride=1, padding=1, is_test=is_test, name="yolo_input")
downsample_ = downsample(conv1, ch_out=conv1.shape[1]*2, is_test=is_test, name="yolo_input.downsample") downsample_ = downsample(conv1, ch_out=conv1.shape[1]*2, is_test=is_test, name="yolo_input.downsample")
index = 2
blocks = [] blocks = []
for i, stage in enumerate(stages): for i, stage in enumerate(stages):
block = layer_warp(block_func, downsample_, 32 *(2**i), stage, is_test=is_test, name="stage.{}".format(i)) block = layer_warp(block_func, downsample_, 32 *(2**i), stage, is_test=is_test, name="stage.{}".format(i))
blocks.append(block) blocks.append(block)
index += 3 * stage
if i < len(stages) - 1: # do not downsaple in the last stage if i < len(stages) - 1: # do not downsaple in the last stage
downsample_ = downsample(block, ch_out=block.shape[1]*2, is_test=is_test, name="stage.{}.downsample".format(i)) downsample_ = downsample(block, ch_out=block.shape[1]*2, is_test=is_test, name="stage.{}.downsample".format(i))
index += 1
return blocks[-1:-4:-1] return blocks[-1:-4:-1]
...@@ -62,8 +62,6 @@ class YOLOv3(object): ...@@ -62,8 +62,6 @@ class YOLOv3(object):
self.outputs = [] self.outputs = []
self.losses = [] self.losses = []
self.downsample = 32 self.downsample = 32
self.ignore_thresh = .7
self.class_num = 80
def build_input(self): def build_input(self):
self.image_shape = [3, cfg.input_size, cfg.input_size] self.image_shape = [3, cfg.input_size, cfg.input_size]
...@@ -132,16 +130,8 @@ class YOLOv3(object): ...@@ -132,16 +130,8 @@ class YOLOv3(object):
route = upsample(route) route = upsample(route)
anchor_mask = [6,7,8,3,4,5,0,1,2] for i, out in enumerate(self.outputs):
anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] anchor_mask = cfg.anchor_masks[i]
for i,out in enumerate(self.outputs):
mask = anchor_mask[i*3 : (i+1)*3]
mask_anchors=[]
for m in mask:
mask_anchors.append(anchors[2 * m])
mask_anchors.append(anchors[2 * m + 1])
class_num = int(self.class_num)
if self.is_train: if self.is_train:
ignore_thresh = float(self.ignore_thresh) ignore_thresh = float(self.ignore_thresh)
...@@ -150,20 +140,24 @@ class YOLOv3(object): ...@@ -150,20 +140,24 @@ class YOLOv3(object):
gtbox=self.gtbox, gtbox=self.gtbox,
gtlabel=self.gtlabel, gtlabel=self.gtlabel,
gtscore=self.gtscore, gtscore=self.gtscore,
anchors=anchors, anchors=cfg.anchors,
anchor_mask=mask, anchor_mask=anchor_mask,
class_num=class_num, class_num=cfg.class_num,
ignore_thresh=ignore_thresh, ignore_thresh=cfg.ignore_thresh,
downsample_ratio=self.downsample, downsample_ratio=self.downsample,
use_label_smooth=cfg.label_smooth, use_label_smooth=cfg.label_smooth,
name="yolo_loss"+str(i)) name="yolo_loss"+str(i))
self.losses.append(fluid.layers.reduce_mean(loss)) self.losses.append(fluid.layers.reduce_mean(loss))
else: else:
mask_anchors=[]
for m in anchor_mask:
mask_anchors.append(cfg.anchors[2 * m])
mask_anchors.append(cfg.anchors[2 * m + 1])
boxes, scores = fluid.layers.yolo_box( boxes, scores = fluid.layers.yolo_box(
x=out, x=out,
img_size=self.im_shape, img_size=self.im_shape,
anchors=mask_anchors, anchors=mask_anchors,
class_num=class_num, class_num=cfg.class_num,
conf_thresh=cfg.valid_thresh, conf_thresh=cfg.valid_thresh,
downsample_ratio=self.downsample, downsample_ratio=self.downsample,
name="yolo_box"+str(i)) name="yolo_box"+str(i))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册