Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
0e40029f
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0e40029f
编写于
8月 14, 2020
作者:
G
Guanghua Yu
提交者:
GitHub
8月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ssdlite-ghostnet model (#1133)
上级
c387335a
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
527 addition
and
0 deletion
+527
-0
configs/ssd/ssdlite_ghostnet.yml
configs/ssd/ssdlite_ghostnet.yml
+162
-0
docs/MODEL_ZOO.md
docs/MODEL_ZOO.md
+1
-0
docs/MODEL_ZOO_cn.md
docs/MODEL_ZOO_cn.md
+1
-0
ppdet/modeling/backbones/__init__.py
ppdet/modeling/backbones/__init__.py
+2
-0
ppdet/modeling/backbones/ghostnet.py
ppdet/modeling/backbones/ghostnet.py
+361
-0
未找到文件。
configs/ssd/ssdlite_ghostnet.yml
0 → 100644
浏览文件 @
0e40029f
architecture
:
SSD
use_gpu
:
true
max_iters
:
400000
snapshot_iter
:
20000
log_smooth_window
:
20
log_iter
:
20
metric
:
COCO
pretrain_weights
:
https://paddle-imagenet-models-name.bj.bcebos.com/GhostNet_x1_3_ssld_pretrained.tar
save_dir
:
output
weights
:
output/ssdlite_ghostnet/model_final
# 80(label_class) + 1(background)
num_classes
:
81
SSD
:
backbone
:
GhostNet
multi_box_head
:
SSDLiteMultiBoxHead
output_decoder
:
background_label
:
0
keep_top_k
:
200
nms_eta
:
1.0
nms_threshold
:
0.45
nms_top_k
:
400
score_threshold
:
0.01
GhostNet
:
scale
:
1.3
extra_block_filters
:
[[
256
,
512
],
[
128
,
256
],
[
128
,
256
],
[
64
,
128
]]
feature_maps
:
[
5
,
7
,
8
,
9
,
10
,
11
]
conv_decay
:
0.00004
lr_mult_list
:
[
0.25
,
0.25
,
0.5
,
0.5
,
0.75
]
SSDLiteMultiBoxHead
:
aspect_ratios
:
[[
2.
],
[
2.
,
3.
],
[
2.
,
3.
],
[
2.
,
3.
],
[
2.
,
3.
],
[
2.
,
3.
]]
base_size
:
320
steps
:
[
16
,
32
,
64
,
107
,
160
,
320
]
flip
:
true
clip
:
true
max_ratio
:
95
min_ratio
:
20
offset
:
0.5
conv_decay
:
0.00004
LearningRate
:
base_lr
:
0.2
schedulers
:
-
!CosineDecay
max_iters
:
400000
-
!LinearWarmup
start_factor
:
0.33333
steps
:
2000
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0005
type
:
L2
TrainReader
:
inputs_def
:
image_shape
:
[
3
,
320
,
320
]
fields
:
[
'
image'
,
'
gt_bbox'
,
'
gt_class'
]
dataset
:
!COCODataSet
dataset_dir
:
dataset/coco
anno_path
:
annotations/instances_train2017.json
image_dir
:
train2017
sample_transforms
:
-
!DecodeImage
to_rgb
:
true
-
!RandomDistort
brightness_lower
:
0.875
brightness_upper
:
1.125
is_order
:
true
-
!RandomExpand
fill_value
:
[
123.675
,
116.28
,
103.53
]
-
!RandomCrop
allow_no_crop
:
false
-
!NormalizeBox
{}
-
!ResizeImage
interp
:
1
target_size
:
320
use_cv2
:
false
-
!RandomFlipImage
is_normalized
:
false
-
!NormalizeImage
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
is_scale
:
true
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
true
batch_size
:
64
shuffle
:
true
drop_last
:
true
# Number of working threads/processes. To speed up, can be set to 16 or 32 etc.
worker_num
:
8
# Size of shared memory used in result queue. After increasing `worker_num`, need expand `memsize`.
memsize
:
8G
# Buffer size for multi threads/processes.one instance in buffer is one batch data.
# To speed up, can be set to 64 or 128 etc.
bufsize
:
32
use_process
:
true
EvalReader
:
inputs_def
:
image_shape
:
[
3
,
320
,
320
]
fields
:
[
'
image'
,
'
gt_bbox'
,
'
gt_class'
,
'
im_shape'
,
'
im_id'
]
dataset
:
!COCODataSet
dataset_dir
:
dataset/coco
anno_path
:
annotations/instances_val2017.json
image_dir
:
val2017
sample_transforms
:
-
!DecodeImage
to_rgb
:
true
-
!NormalizeBox
{}
-
!ResizeImage
interp
:
1
target_size
:
320
use_cv2
:
false
-
!NormalizeImage
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
is_scale
:
true
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
batch_size
:
8
worker_num
:
8
bufsize
:
32
use_process
:
false
TestReader
:
inputs_def
:
image_shape
:
[
3
,
320
,
320
]
fields
:
[
'
image'
,
'
im_id'
,
'
im_shape'
]
dataset
:
!ImageFolder
anno_path
:
annotations/instances_val2017.json
sample_transforms
:
-
!DecodeImage
to_rgb
:
true
-
!ResizeImage
interp
:
1
max_size
:
0
target_size
:
320
use_cv2
:
false
-
!NormalizeImage
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
is_scale
:
true
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
batch_size
:
1
docs/MODEL_ZOO.md
浏览文件 @
0e40029f
...
...
@@ -200,6 +200,7 @@ results of image size 608/416/320 above. Deformable conv is added on stage 5 of
| MobileNet_v3 large | 320 | 64 | Cosine decay(40w) | - | 23.3 |
[
model
](
https://paddlemodels.bj.bcebos.com/object_detection/mobile_models/ssdlite_mobilenet_v3_large.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ssd/ssdlite_mobilenet_v3_large.yml
)
|
| MobileNet_v3 small w/ FPN | 320 | 64 | Cosine decay(40w) | - | 18.9 |
[
model
](
https://paddlemodels.bj.bcebos.com/object_detection/mobile_models/ssdlite_mobilenet_v3_small_fpn.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ssd/ssdlite_mobilenet_v3_small_fpn.yml
)
|
| MobileNet_v3 large w/ FPN | 320 | 64 | Cosine decay(40w) | - | 24.3 |
[
model
](
https://paddlemodels.bj.bcebos.com/object_detection/mobile_models/ssdlite_mobilenet_v3_large_fpn.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ssd/ssdlite_mobilenet_v3_large_fpn.yml
)
|
| GhostNet | 320 | 64 | Cosine decay(40w) | - | 23.3 |
[
model
](
htts://paddlemodels.bj.bcebos.com/object_detection/mobile_models/ssdlite_ghostnet.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ssd/ssdlite_ghostnet.yml
)
|
**Notes:**
`SSDLite`
is trained in 8 GPU with total batch size as 512 and uses cosine decay strategy to train.
...
...
docs/MODEL_ZOO_cn.md
浏览文件 @
0e40029f
...
...
@@ -192,6 +192,7 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
| MobileNet_v3 large | 320 | 64 | Cosine decay(40w) | - | 23.3 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/mobile_models/ssdlite_mobilenet_v3_large.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ssd/ssdlite_mobilenet_v3_large.yml
)
|
| MobileNet_v3 small w/ FPN | 320 | 64 | Cosine decay(40w) | - | 18.9 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/mobile_models/ssdlite_mobilenet_v3_small_fpn.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ssd/ssdlite_mobilenet_v3_small_fpn.yml
)
|
| MobileNet_v3 large w/ FPN | 320 | 64 | Cosine decay(40w) | - | 24.3 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/mobile_models/ssdlite_mobilenet_v3_large_fpn.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ssd/ssdlite_mobilenet_v3_large_fpn.yml
)
|
| GhostNet | 320 | 64 | Cosine decay(40w) | - | 23.3 |
[
下载链接
](
htts://paddlemodels.bj.bcebos.com/object_detection/mobile_models/ssdlite_ghostnet.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ssd/ssdlite_ghostnet.yml
)
|
**注意事项:**
SSDLite模型使用学习率余弦衰减策略在8卡GPU下总batch size为512。
...
...
ppdet/modeling/backbones/__init__.py
浏览文件 @
0e40029f
...
...
@@ -34,6 +34,7 @@ from . import efficientnet
from
.
import
bifpn
from
.
import
cspdarknet
from
.
import
acfpn
from
.
import
ghostnet
from
.resnet
import
*
from
.resnext
import
*
...
...
@@ -55,3 +56,4 @@ from .efficientnet import *
from
.bifpn
import
*
from
.cspdarknet
import
*
from
.acfpn
import
*
from
.ghostnet
import
*
ppdet/modeling/backbones/ghostnet.py
0 → 100644
浏览文件 @
0e40029f
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
collections
import
OrderedDict
from
ppdet.core.workspace
import
register
__all__
=
[
"GhostNet"
]
@
register
class
GhostNet
(
object
):
"""
scale (float): scaling factor for convolution groups proportion of GhostNet.
feature_maps (list): index of stages whose feature maps are returned.
conv_decay (float): weight decay for convolution layer weights.
extra_block_filters (list): number of filter for each extra block.
lr_mult_list (list): learning rate ratio of different blocks, lower learning rate ratio
is need for pretrained model got using distillation(default as
[1.0, 1.0, 1.0, 1.0, 1.0]).
"""
def
__init__
(
self
,
scale
,
feature_maps
=
[
5
,
6
,
7
,
8
,
9
,
10
],
conv_decay
=
0.00001
,
extra_block_filters
=
[[
256
,
512
],
[
128
,
256
],
[
128
,
256
],
[
64
,
128
]],
lr_mult_list
=
[
1.0
,
1.0
,
1.0
,
1.0
,
1.0
],
freeze_norm
=
False
):
self
.
scale
=
scale
self
.
feature_maps
=
feature_maps
self
.
extra_block_filters
=
extra_block_filters
self
.
end_points
=
[]
self
.
block_stride
=
0
self
.
conv_decay
=
conv_decay
self
.
lr_mult_list
=
lr_mult_list
self
.
freeze_norm
=
freeze_norm
self
.
curr_stage
=
0
self
.
cfgs
=
[
# k, t, c, se, s
[
3
,
16
,
16
,
0
,
1
],
[
3
,
48
,
24
,
0
,
2
],
[
3
,
72
,
24
,
0
,
1
],
[
5
,
72
,
40
,
1
,
2
],
[
5
,
120
,
40
,
1
,
1
],
[
3
,
240
,
80
,
0
,
2
],
[
3
,
200
,
80
,
0
,
1
],
[
3
,
184
,
80
,
0
,
1
],
[
3
,
184
,
80
,
0
,
1
],
[
3
,
480
,
112
,
1
,
1
],
[
3
,
672
,
112
,
1
,
1
],
[
5
,
672
,
160
,
1
,
2
],
[
5
,
960
,
160
,
0
,
1
],
[
5
,
960
,
160
,
1
,
1
],
[
5
,
960
,
160
,
0
,
1
],
[
5
,
960
,
160
,
1
,
1
]
]
def
_conv_bn_layer
(
self
,
input
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
None
):
lr_idx
=
self
.
curr_stage
//
3
lr_idx
=
min
(
lr_idx
,
len
(
self
.
lr_mult_list
)
-
1
)
lr_mult
=
self
.
lr_mult_list
[
lr_idx
]
norm_lr
=
0.
if
self
.
freeze_norm
else
lr_mult
x
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
act
=
None
,
param_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
self
.
conv_decay
),
learning_rate
=
lr_mult
,
initializer
=
fluid
.
initializer
.
MSRA
(),
name
=
name
+
"_weights"
),
bias_attr
=
False
)
bn_name
=
name
+
"_bn"
x
=
fluid
.
layers
.
batch_norm
(
input
=
x
,
act
=
act
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
"_scale"
,
learning_rate
=
norm_lr
,
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
name
=
bn_name
+
"_offset"
,
learning_rate
=
norm_lr
,
regularizer
=
L2Decay
(
0.0
)),
moving_mean_name
=
bn_name
+
"_mean"
,
moving_variance_name
=
name
+
"_variance"
)
return
x
def
se_block
(
self
,
input
,
num_channels
,
reduction_ratio
=
4
,
name
=
None
):
lr_idx
=
self
.
curr_stage
//
3
lr_idx
=
min
(
lr_idx
,
len
(
self
.
lr_mult_list
)
-
1
)
lr_mult
=
self
.
lr_mult_list
[
lr_idx
]
pool
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_type
=
'avg'
,
global_pooling
=
True
,
use_cudnn
=
False
)
stdv
=
1.0
/
math
.
sqrt
(
pool
.
shape
[
1
]
*
1.0
)
squeeze
=
fluid
.
layers
.
fc
(
input
=
pool
,
size
=
num_channels
//
reduction_ratio
,
act
=
'relu'
,
param_attr
=
ParamAttr
(
learning_rate
=
lr_mult
,
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
),
name
=
name
+
'_1_weights'
),
bias_attr
=
ParamAttr
(
name
=
name
+
'_1_offset'
,
learning_rate
=
lr_mult
))
stdv
=
1.0
/
math
.
sqrt
(
squeeze
.
shape
[
1
]
*
1.0
)
excitation
=
fluid
.
layers
.
fc
(
input
=
squeeze
,
size
=
num_channels
,
act
=
None
,
param_attr
=
ParamAttr
(
learning_rate
=
lr_mult
,
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
),
name
=
name
+
'_2_weights'
),
bias_attr
=
ParamAttr
(
name
=
name
+
'_2_offset'
,
learning_rate
=
lr_mult
))
excitation
=
fluid
.
layers
.
clip
(
x
=
excitation
,
min
=
0
,
max
=
1
)
se_scale
=
fluid
.
layers
.
elementwise_mul
(
x
=
input
,
y
=
excitation
,
axis
=
0
)
return
se_scale
def
depthwise_conv
(
self
,
input
,
output
,
kernel_size
,
stride
=
1
,
relu
=
False
,
name
=
None
):
return
self
.
_conv_bn_layer
(
input
=
input
,
num_filters
=
output
,
filter_size
=
kernel_size
,
stride
=
stride
,
groups
=
input
.
shape
[
1
],
act
=
"relu"
if
relu
else
None
,
name
=
name
+
"_depthwise"
)
def
ghost_module
(
self
,
input
,
output
,
kernel_size
=
1
,
ratio
=
2
,
dw_size
=
3
,
stride
=
1
,
relu
=
True
,
name
=
None
):
self
.
output
=
output
init_channels
=
int
(
math
.
ceil
(
output
/
ratio
))
new_channels
=
int
(
init_channels
*
(
ratio
-
1
))
primary_conv
=
self
.
_conv_bn_layer
(
input
=
input
,
num_filters
=
init_channels
,
filter_size
=
kernel_size
,
stride
=
stride
,
groups
=
1
,
act
=
"relu"
if
relu
else
None
,
name
=
name
+
"_primary_conv"
)
cheap_operation
=
self
.
_conv_bn_layer
(
input
=
primary_conv
,
num_filters
=
new_channels
,
filter_size
=
dw_size
,
stride
=
1
,
groups
=
init_channels
,
act
=
"relu"
if
relu
else
None
,
name
=
name
+
"_cheap_operation"
)
out
=
fluid
.
layers
.
concat
([
primary_conv
,
cheap_operation
],
axis
=
1
)
return
out
def
ghost_bottleneck
(
self
,
input
,
hidden_dim
,
output
,
kernel_size
,
stride
,
use_se
,
name
=
None
):
inp_channels
=
input
.
shape
[
1
]
x
=
self
.
ghost_module
(
input
=
input
,
output
=
hidden_dim
,
kernel_size
=
1
,
stride
=
1
,
relu
=
True
,
name
=
name
+
"_ghost_module_1"
)
if
self
.
block_stride
==
4
and
stride
==
2
:
self
.
block_stride
+=
1
if
self
.
block_stride
in
self
.
feature_maps
:
self
.
end_points
.
append
(
x
)
if
stride
==
2
:
x
=
self
.
depthwise_conv
(
input
=
x
,
output
=
hidden_dim
,
kernel_size
=
kernel_size
,
stride
=
stride
,
relu
=
False
,
name
=
name
+
"_depthwise"
)
if
use_se
:
x
=
self
.
se_block
(
input
=
x
,
num_channels
=
hidden_dim
,
name
=
name
+
"_se"
)
x
=
self
.
ghost_module
(
input
=
x
,
output
=
output
,
kernel_size
=
1
,
relu
=
False
,
name
=
name
+
"_ghost_module_2"
)
if
stride
==
1
and
inp_channels
==
output
:
shortcut
=
input
else
:
shortcut
=
self
.
depthwise_conv
(
input
=
input
,
output
=
inp_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
relu
=
False
,
name
=
name
+
"_shortcut_depthwise"
)
shortcut
=
self
.
_conv_bn_layer
(
input
=
shortcut
,
num_filters
=
output
,
filter_size
=
1
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
name
+
"_shortcut_conv"
)
return
fluid
.
layers
.
elementwise_add
(
x
=
x
,
y
=
shortcut
,
axis
=-
1
)
def
_extra_block_dw
(
self
,
input
,
num_filters1
,
num_filters2
,
stride
,
name
=
None
):
pointwise_conv
=
self
.
_conv_bn_layer
(
input
=
input
,
filter_size
=
1
,
num_filters
=
int
(
num_filters1
),
stride
=
1
,
act
=
'relu6'
,
name
=
name
+
"_extra1"
)
depthwise_conv
=
self
.
_conv_bn_layer
(
input
=
pointwise_conv
,
filter_size
=
3
,
num_filters
=
int
(
num_filters2
),
stride
=
stride
,
groups
=
int
(
num_filters1
),
act
=
'relu6'
,
name
=
name
+
"_extra2_dw"
)
normal_conv
=
self
.
_conv_bn_layer
(
input
=
depthwise_conv
,
filter_size
=
1
,
num_filters
=
int
(
num_filters2
),
stride
=
1
,
act
=
'relu6'
,
name
=
name
+
"_extra2_sep"
)
return
normal_conv
def
_make_divisible
(
self
,
v
,
divisor
=
8
,
min_value
=
None
):
if
min_value
is
None
:
min_value
=
divisor
new_v
=
max
(
min_value
,
int
(
v
+
divisor
/
2
)
//
divisor
*
divisor
)
if
new_v
<
0.9
*
v
:
new_v
+=
divisor
return
new_v
def
__call__
(
self
,
input
):
# build first layer
output_channel
=
int
(
self
.
_make_divisible
(
16
*
self
.
scale
,
4
))
x
=
self
.
_conv_bn_layer
(
input
=
input
,
num_filters
=
output_channel
,
filter_size
=
3
,
stride
=
2
,
groups
=
1
,
act
=
"relu"
,
name
=
"conv1"
)
# build inverted residual blocks
idx
=
0
for
k
,
exp_size
,
c
,
use_se
,
s
in
self
.
cfgs
:
if
s
==
2
:
self
.
block_stride
+=
1
if
self
.
block_stride
in
self
.
feature_maps
:
self
.
end_points
.
append
(
x
)
output_channel
=
int
(
self
.
_make_divisible
(
c
*
self
.
scale
,
4
))
hidden_channel
=
int
(
self
.
_make_divisible
(
exp_size
*
self
.
scale
,
4
))
x
=
self
.
ghost_bottleneck
(
input
=
x
,
hidden_dim
=
hidden_channel
,
output
=
output_channel
,
kernel_size
=
k
,
stride
=
s
,
use_se
=
use_se
,
name
=
"_ghostbottleneck_"
+
str
(
idx
))
idx
+=
1
self
.
curr_stage
+=
1
self
.
block_stride
+=
1
if
self
.
block_stride
in
self
.
feature_maps
:
self
.
end_points
.
append
(
conv
)
# extra block
# check whether conv_extra is needed
if
self
.
block_stride
<
max
(
self
.
feature_maps
):
conv_extra
=
self
.
_conv_bn_layer
(
x
,
num_filters
=
self
.
_make_divisible
(
self
.
scale
*
self
.
cfgs
[
-
1
][
1
]),
filter_size
=
1
,
stride
=
1
,
groups
=
1
,
act
=
'relu6'
,
name
=
'conv'
+
str
(
idx
+
2
))
self
.
block_stride
+=
1
if
self
.
block_stride
in
self
.
feature_maps
:
self
.
end_points
.
append
(
conv_extra
)
idx
+=
1
for
block_filter
in
self
.
extra_block_filters
:
conv_extra
=
self
.
_extra_block_dw
(
conv_extra
,
block_filter
[
0
],
block_filter
[
1
],
2
,
'conv'
+
str
(
idx
+
2
))
self
.
block_stride
+=
1
if
self
.
block_stride
in
self
.
feature_maps
:
self
.
end_points
.
append
(
conv_extra
)
idx
+=
1
return
OrderedDict
([(
'ghost_{}'
.
format
(
idx
),
feat
)
for
idx
,
feat
in
enumerate
(
self
.
end_points
)])
return
res
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录