未验证 提交 204bcbdf 编写于 作者: S sucuicong 提交者: GitHub

a version of ppyolo for EdgeBoard (#4243)

* a version of ppyolo for EdgeBoard

* a version of ppyolo for EdgeBoard
上级 e83c3ecf
......@@ -167,7 +167,7 @@ class DecodeCache(BaseOperator):
'''
super(DecodeCache, self).__init__()
self.use_cache = False if cache_root is None else True
self.use_cache = False if cache_root is None else True
self.cache_root = cache_root
if cache_root is not None:
......@@ -175,7 +175,8 @@ class DecodeCache(BaseOperator):
def apply(self, sample, context=None):
if self.use_cache and os.path.exists(self.cache_path(self.cache_root, sample['im_file'])):
if self.use_cache and os.path.exists(
self.cache_path(self.cache_root, sample['im_file'])):
path = self.cache_path(self.cache_root, sample['im_file'])
im = self.load(path)
......@@ -191,7 +192,8 @@ class DecodeCache(BaseOperator):
sample['ori_image'] = im
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if self.use_cache and not os.path.exists(self.cache_path(self.cache_root, sample['im_file'])):
if self.use_cache and not os.path.exists(
self.cache_path(self.cache_root, sample['im_file'])):
path = self.cache_path(self.cache_root, sample['im_file'])
self.dump(im, path)
......@@ -212,7 +214,7 @@ class DecodeCache(BaseOperator):
def load(path):
with open(path, 'rb') as f:
im = pickle.load(f)
return im
return im
@staticmethod
def dump(obj, path):
......@@ -227,6 +229,7 @@ class DecodeCache(BaseOperator):
finally:
MUTEX.release()
@register_op
class Permute(BaseOperator):
def __init__(self):
......
architecture: YOLOv3
use_gpu: true
max_iters: 500000
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_vd_pretrained.tar
weights: output/ppyolo_eb/best_model
num_classes: 80
use_fine_grained_loss: true
log_iter: 1000
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: ResNet_EB
yolo_head: EBHead
ResNet_EB:
norm_type: sync_bn
freeze_at: 0
freeze_norm: false
norm_decay: 0.
depth: 34
variant: d
feature_maps: [3, 4, 5]
EBHead:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
score_threshold: 0.01
YOLOv3Loss:
ignore_thresh: 0.7
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
max_height: 608
max_width: 608
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 320000
- 450000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'ppyolo_reader.yml'
architecture: YOLOv3
use_gpu: true
max_iters: 70000
log_smooth_window: 20
save_dir: output
snapshot_iter: 3000
metric: VOC
map_type: integral
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_vd_pretrained.tar
weights: output/ppyolo_eb_voc/best_model
num_classes: 20
use_fine_grained_loss: true
log_iter: 1000
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: ResNet_EB
yolo_head: EBHead
ResNet_EB:
norm_type: sync_bn
freeze_at: 0
freeze_norm: false
norm_decay: 0.
depth: 34
variant: d
feature_maps: [3, 4, 5]
EBHead:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
score_threshold: 0.01
YOLOv3Loss:
ignore_thresh: 0.7
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
max_height: 608
max_width: 608
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 35000
- 60000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'ppyolo_reader.yml'
TrainReader:
dataset:
!VOCDataSet
dataset_dir: dataset/voc
anno_path: trainval.txt
use_default_label: false
with_background: false
mixup_epoch: 200
batch_size: 8
EvalReader:
inputs_def:
image_shape: [3, 608, 608]
fields: ['image', 'im_size', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult']
num_max_boxes: 50
dataset:
!VOCDataSet
dataset_dir: dataset/voc
anno_path: test.txt
use_default_label: false
with_background: false
TestReader:
dataset:
!ImageFolder
use_default_label: false
with_background: false
......@@ -22,6 +22,7 @@ from . import corner_head
from . import efficient_head
from . import ttf_head
from . import solov2_head
from . import eb_head
from .rpn_head import *
from .yolo_head import *
......@@ -31,3 +32,4 @@ from .corner_head import *
from .efficient_head import *
from .ttf_head import *
from .solov2_head import *
from .eb_head import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
from ppdet.core.workspace import register
__all__ = ['EBHead']
@register
class EBHead(object):
"""
Head block for pp-yolo-eb, ppyolo for EdgeBoard : https://ai.baidu.com/ai-doc/HWCE/Yk3b86gvp
Args:
norm_decay (float): weight decay for normalization layer weights
num_classes (int): number of output classes
anchors (list): anchors
anchor_masks (list): anchor masks
nms (object): an instance of `MultiClassNMS`
"""
__inject__ = ['yolo_loss', 'nms']
__shared__ = ['num_classes', 'weight_prefix_name']
def __init__(self,
norm_decay=0.,
num_classes=80,
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
drop_block=False,
block_size=3,
keep_prob=0.9,
yolo_loss="YOLOv3Loss",
nms=MultiClassNMS(
score_threshold=0.01,
nms_top_k=1000,
keep_top_k=100,
nms_threshold=0.45,
background_label=-1).__dict__,
weight_prefix_name=''):
self.norm_decay = norm_decay
self.num_classes = num_classes
self.anchor_masks = anchor_masks
self._parse_anchors(anchors)
self.yolo_loss = yolo_loss
self.nms = nms
self.prefix_name = weight_prefix_name
self.drop_block = drop_block
self.block_size = block_size
self.keep_prob = keep_prob
if isinstance(nms, dict):
self.nms = MultiClassNMS(**nms)
def _conv_bn(self,
input,
ch_out,
filter_size,
stride,
padding,
act='leaky',
is_test=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(name=name + ".conv.weights"),
bias_attr=False)
bn_name = name + ".bn"
bn_param_attr = ParamAttr(
regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
bn_bias_attr = ParamAttr(
regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
out = fluid.layers.batch_norm(
input=conv,
act=None,
is_test=is_test,
param_attr=bn_param_attr,
bias_attr=bn_bias_attr,
moving_mean_name=bn_name + '.mean',
moving_variance_name=bn_name + '.var')
if act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
def _detection_block(self, input, channel, is_test=True, name=None):
assert channel % 2 == 0, \
"channel {} cannot be divided by 2 in detection block {}" \
.format(channel, name)
conv = input
conv = self._conv_bn(
conv,
channel,
filter_size=1,
stride=1,
padding=0,
is_test=is_test,
name='{}.0'.format(name))
for j in range(4):
conv = self._conv_bn(
conv,
channel,
filter_size=3,
stride=1,
padding=1,
is_test=is_test,
name='{}.{}.1'.format(name, j))
if j == 1:
route = conv
return route, conv
def _upsample(self, input, scale=2, name=None):
out = fluid.layers.resize_nearest(
input=input, scale=float(scale), name=name)
return out
def _pool_concat(self, input):
pool1 = fluid.layers.pool2d(
input=input,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
pool2 = fluid.layers.pool2d(
input=input,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='avg')
out = fluid.layers.concat(input=[pool1, pool2], axis=1)
return out
def _parse_anchors(self, anchors):
"""
Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors
"""
self.anchors = []
self.mask_anchors = []
assert len(anchors) > 0, "ANCHORS not set."
assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
for anchor in anchors:
assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
self.anchors.extend(anchor)
anchor_num = len(anchors)
for masks in self.anchor_masks:
self.mask_anchors.append([])
for mask in masks:
assert mask < anchor_num, "anchor mask index overflow"
self.mask_anchors[-1].extend(anchors[mask])
def _get_outputs(self, input, is_train=True):
"""
Get ppyolo_eb head output
Args:
input (list): List of Variables, output of backbone stages
is_train (bool): whether in train or test mode
Returns:
outputs (list): Variables of each output layer
"""
outputs = []
# get last out_layer_num blocks in reverse order
out_layer_num = len(self.anchor_masks)
blocks = input[-1:-out_layer_num - 1:-1]
filters_num1 = blocks[1].shape[1] // 2
blk0 = self._pool_concat(blocks[2])
blk0 = self._conv_bn(
blk0,
filters_num1,
filter_size=1,
stride=1,
padding=0,
is_test=False,
name='channel_fusion_1')
blk1 = fluid.layers.concat(input=[blk0, blocks[1]], axis=1)
filters_num2 = blocks[0].shape[1] // 2
blk = self._conv_bn(
blk1,
filters_num2,
filter_size=1,
stride=1,
padding=0,
is_test=False,
name='channel_fusion_2')
blk2 = self._conv_bn(
blk,
filters_num2,
filter_size=3,
stride=1,
padding=1,
is_test=False,
name='feature_fusion')
blk2 = self._pool_concat(blk2)
blk2 = self._conv_bn(
blk2,
filters_num2,
filter_size=1,
stride=1,
padding=0,
is_test=False,
name='channel_fusion_3')
blk3 = fluid.layers.concat(input=[blk2, blocks[0]], axis=1)
blocks = [blk3, blk1, blocks[2]]
route = None
for i, block in enumerate(blocks):
if i > 0: # perform concat in first 2 detection_block
block = fluid.layers.concat(input=[route, block], axis=1)
route, tip = self._detection_block(
block,
channel=512 // (2**i),
is_test=(not is_train),
name=self.prefix_name + "yolo_block.{}".format(i))
# out channel number = mask_num * (5 + class_num)
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
with fluid.name_scope('yolo_output'):
block_out = fluid.layers.conv2d(
input=tip,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(
name=self.prefix_name +
"yolo_output.{}.conv.weights".format(i)),
bias_attr=ParamAttr(
regularizer=L2Decay(0.),
name=self.prefix_name +
"yolo_output.{}.conv.bias".format(i)))
outputs.append(block_out)
if i < len(blocks) - 1:
# do not perform upsample in the last detection_block
route = self._conv_bn(
input=route,
ch_out=256 // (2**i),
filter_size=1,
stride=1,
padding=0,
is_test=(not is_train),
name=self.prefix_name + "yolo_transition.{}".format(i))
# upsample
route = self._upsample(route)
return outputs
def get_loss(self, input, gt_box, gt_label, gt_score, targets):
"""
Get final loss of network of ppyolo_eb.
Args:
input (list): List of Variables, output of backbone stages
gt_box (Variable): The ground-truth boudding boxes.
gt_label (Variable): The ground-truth class labels.
gt_score (Variable): The ground-truth boudding boxes mixup scores.
targets ([Variables]): List of Variables, the targets for yolo
loss calculatation.
Returns:
loss (Variable): The loss Variable of ppyolo_eb network.
"""
outputs = self._get_outputs(input, is_train=True)
return self.yolo_loss(outputs, gt_box, gt_label, gt_score, targets,
self.anchors, self.anchor_masks,
self.mask_anchors, self.num_classes,
self.prefix_name)
def get_prediction(self, input, im_size, exclude_nms=False):
"""
Get prediction result of ppyolo_eb network
Args:
input (list): List of Variables, output of backbone stages
im_size (Variable): Variable of size([h, w]) of each image
Returns:
pred (Variable): The prediction result after non-max suppress.
"""
outputs = self._get_outputs(input, is_train=False)
boxes = []
scores = []
downsample = 32
for i, output in enumerate(outputs):
box, score = fluid.layers.yolo_box(
x=output,
img_size=im_size,
anchors=self.mask_anchors[i],
class_num=self.num_classes,
conf_thresh=self.nms.score_threshold,
downsample_ratio=downsample,
name=self.prefix_name + "yolo_box" + str(i))
boxes.append(box)
scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
downsample //= 2
yolo_boxes = fluid.layers.concat(boxes, axis=1)
yolo_scores = fluid.layers.concat(scores, axis=2)
# Only for benchmark, postprocess(NMS) is not needed
if exclude_nms:
return {'bbox': yolo_boxes, 'score': yolo_scores}
pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
return {'bbox': pred}
......@@ -35,6 +35,7 @@ from . import bifpn
from . import cspdarknet
from . import acfpn
from . import ghostnet
from . import resnet_eb
from .resnet import *
from .resnext import *
......@@ -57,3 +58,4 @@ from .bifpn import *
from .cspdarknet import *
from .acfpn import *
from .ghostnet import *
from .resnet_eb import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import Constant
from ppdet.core.workspace import register, serializable
from numbers import Integral
from .name_adapter import NameAdapter
__all__ = ['ResNet_EB']
@register
@serializable
class ResNet_EB(object):
"""
modified ResNet, especially for EdgeBoard: https://ai.baidu.com/ai-doc/HWCE/Yk3b86gvp
"""
__shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name']
def __init__(self,
depth=50,
freeze_at=2,
norm_type='affine_channel',
freeze_norm=True,
norm_decay=0.,
variant='b',
feature_maps=[2, 3, 4, 5],
weight_prefix_name='',
lr_mult_list=[1., 1., 1., 1.]):
super(ResNet_EB, self).__init__()
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
assert depth in [18, 34, 50, 101, 152, 200], \
"depth {} not in [18, 34, 50, 101, 152, 200]"
assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant"
assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
assert len(feature_maps) > 0, "need one or more feature maps"
assert norm_type in ['bn', 'sync_bn', 'affine_channel']
assert len(lr_mult_list
) == 4, "lr_mult_list length must be 4 but got {}".format(
len(lr_mult_list))
self.depth = depth
self.freeze_at = freeze_at
self.norm_type = norm_type
self.norm_decay = norm_decay
self.freeze_norm = freeze_norm
self.variant = variant
self._model_type = 'ResNet'
self.feature_maps = feature_maps
self.depth_cfg = {
18: ([2, 2, 2, 2], self.basicblock),
34: ([3, 4, 6, 3], self.basicblock),
50: ([3, 4, 6, 3], self.bottleneck),
101: ([3, 4, 23, 3], self.bottleneck),
152: ([3, 8, 36, 3], self.bottleneck),
200: ([3, 12, 48, 3], self.bottleneck),
}
self.stage_filters = [64, 128, 256, 512]
self._c1_out_chan_num = 64
self.na = NameAdapter(self)
self.prefix_name = weight_prefix_name
self.lr_mult_list = lr_mult_list
# var denoting curr stage
self.stage_num = -1
def _conv_offset(self,
input,
filter_size,
stride,
padding,
act=None,
name=None):
out_channel = filter_size * filter_size * 3
out = fluid.layers.conv2d(
input,
num_filters=out_channel,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=ParamAttr(
initializer=Constant(0.0), name=name + ".w_0"),
bias_attr=ParamAttr(
initializer=Constant(0.0), name=name + ".b_0"),
act=act,
name=name)
return out
def _conv_norm(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None,
dcn_v2=False):
_name = self.prefix_name + name if self.prefix_name != '' else name
# need fine lr for distilled model, default as 1.0
lr_mult = 1.0
mult_idx = max(self.stage_num - 2, 0)
mult_idx = min(self.stage_num - 2, 3)
lr_mult = self.lr_mult_list[mult_idx]
if not dcn_v2:
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(
name=_name + "_weights", learning_rate=lr_mult),
bias_attr=False,
name=_name + '.conv2d.output.1')
else:
# select deformable conv"
offset_mask = self._conv_offset(
input=input,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
act=None,
name=_name + "_conv_offset")
offset_channel = filter_size**2 * 2
mask_channel = filter_size**2
offset, mask = fluid.layers.split(
input=offset_mask,
num_or_sections=[offset_channel, mask_channel],
dim=1)
mask = fluid.layers.sigmoid(mask)
conv = fluid.layers.deformable_conv(
input=input,
offset=offset,
mask=mask,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
deformable_groups=1,
im2col_step=1,
param_attr=ParamAttr(
name=_name + "_weights", learning_rate=lr_mult),
bias_attr=False,
name=_name + ".conv2d.output.1")
bn_name = self.na.fix_conv_norm_name(name)
bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name
norm_lr = 0. if self.freeze_norm else lr_mult
norm_decay = self.norm_decay
pattr = ParamAttr(
name=bn_name + '_scale',
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
battr = ParamAttr(
name=bn_name + '_offset',
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
if self.norm_type in ['bn', 'sync_bn']:
global_stats = True if self.freeze_norm else False
out = fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=pattr,
bias_attr=battr,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif self.norm_type == 'affine_channel':
scale = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=pattr,
default_initializer=fluid.initializer.Constant(1.))
bias = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=battr,
default_initializer=fluid.initializer.Constant(0.))
out = fluid.layers.affine_channel(
x=conv, scale=scale, bias=bias, act=act)
if self.freeze_norm:
scale.stop_gradient = True
bias.stop_gradient = True
return out
def _shortcut(self, input, ch_out, stride, is_first, name):
max_pooling_in_short_cut = self.variant == 'd'
ch_in = input.shape[1]
# the naming rule is same as pretrained weight
name = self.na.fix_shortcut_name(name)
std_senet = getattr(self, 'std_senet', False)
if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first):
if std_senet:
if is_first:
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return self._conv_norm(input, ch_out, 3, stride, name=name)
if max_pooling_in_short_cut and not is_first:
input1 = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
ceil_mode=True,
pool_type='max')
input2 = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
ceil_mode=True,
pool_type='avg')
input = fluid.layers.elementwise_add(
x=input1, y=input2, name=name + ".pool.add")
return self._conv_norm(input, ch_out, 1, 1, name=name)
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck(self,
input,
num_filters,
stride,
is_first,
name,
dcn_v2=False,
gcb=False,
gcb_name=None):
assert dcn_v2 is False, "Not implemented in EdgeBoard yet."
assert gcb is False, "Not implemented in EdgeBoard yet."
if self.variant == 'a':
stride1, stride2 = stride, 1
else:
stride1, stride2 = 1, stride
# ResNeXt
groups = getattr(self, 'groups', 1)
group_width = getattr(self, 'group_width', -1)
if groups == 1:
expand = 4
elif (groups * group_width) == 256:
expand = 1
else: # FIXME hard code for now, handles 32x4d, 64x4d and 32x8d
num_filters = num_filters // 2
expand = 2
conv_name1, conv_name2, conv_name3, \
shortcut_name = self.na.fix_bottleneck_name(name)
std_senet = getattr(self, 'std_senet', False)
if std_senet:
conv_def = [
[int(num_filters / 2), 1, stride1, 'relu', 1, conv_name1],
[num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]
]
else:
conv_def = [[num_filters, 1, stride1, 'relu', 1, conv_name1],
[num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]]
residual = input
for i, (c, k, s, act, g, _name) in enumerate(conv_def):
residual = self._conv_norm(
input=residual,
num_filters=c,
filter_size=k,
stride=s,
act=act,
groups=g,
name=_name,
dcn_v2=False)
short = self._shortcut(
input,
num_filters * expand,
stride,
is_first=is_first,
name=shortcut_name)
# Squeeze-and-Excitation
if callable(getattr(self, '_squeeze_excitation', None)):
residual = self._squeeze_excitation(
input=residual, num_channels=num_filters, name='fc' + name)
return fluid.layers.elementwise_add(
x=short, y=residual, act='relu', name=name + ".add.output.5")
def basicblock(self,
input,
num_filters,
stride,
is_first,
name,
dcn_v2=False,
gcb=False,
gcb_name=None):
assert dcn_v2 is False, "Not implemented in EdgeBoard yet."
assert gcb is False, "Not implemented EdgeBoard yet."
conv0 = self._conv_norm(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a")
conv1 = self._conv_norm(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
short = self._shortcut(
input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
def layer_warp(self, input, stage_num):
"""
Args:
input (Variable): input variable.
stage_num (int): the stage number, should be 2, 3, 4, 5
Returns:
The last variable in endpoint-th stage.
"""
assert stage_num in [2, 3, 4, 5]
self.stage_num = stage_num
stages, block_func = self.depth_cfg[self.depth]
count = stages[stage_num - 2]
ch_out = self.stage_filters[stage_num - 2]
is_first = False if stage_num != 2 else True
# Make the layer name and parameter name consistent
# with ImageNet pre-trained model
conv = input
for i in range(count):
conv_name = self.na.fix_layer_warp_name(stage_num, count, i)
if self.depth < 50:
is_first = True if i == 0 and stage_num == 2 else False
conv = block_func(
input=conv,
num_filters=ch_out,
stride=2 if i == 0 and stage_num != 2 else 1,
is_first=is_first,
name=conv_name,
dcn_v2=False,
gcb=False,
gcb_name=None)
return conv
def c1_stage(self, input):
out_chan = self._c1_out_chan_num
conv1_name = self.na.fix_c1_stage_name()
if self.variant in ['c', 'd']:
conv_def = [
[out_chan // 2, 3, 2, "conv1_1"],
[out_chan // 2, 3, 1, "conv1_2"],
[out_chan, 3, 1, "conv1_3"],
]
else:
conv_def = [[out_chan, 7, 2, conv1_name]]
for (c, k, s, _name) in conv_def:
input = self._conv_norm(
input=input,
num_filters=c,
filter_size=k,
stride=s,
act='relu',
name=_name)
output = fluid.layers.pool2d(
input=input,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
return output
def __call__(self, input):
assert isinstance(input, Variable)
assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
"feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)
res_endpoints = []
res = input
feature_maps = self.feature_maps
severed_head = getattr(self, 'severed_head', False)
if not severed_head:
res = self.c1_stage(res)
feature_maps = range(2, max(self.feature_maps) + 1)
for i in feature_maps:
res = self.layer_warp(res, i)
if i in self.feature_maps:
res_endpoints.append(res)
if self.freeze_at >= i:
res.stop_gradient = True
return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
for idx, feat in enumerate(res_endpoints)])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册