Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
8e135a4d
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
1 年多 前同步成功
通知
283
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8e135a4d
编写于
10月 21, 2020
作者:
H
haoyuying
提交者:
GitHub
10月 21, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mobilenet series
上级
bd8e0ad0
变更
9
展开全部
隐藏空白更改
内联
并排
Showing
9 changed file
with
980 addition
and
1550 deletion
+980
-1550
hub_module/modules/image/classification/mobilenet_v1_imagenet/__init__.py
...es/image/classification/mobilenet_v1_imagenet/__init__.py
+0
-0
hub_module/modules/image/classification/mobilenet_v1_imagenet/data_feed.py
...s/image/classification/mobilenet_v1_imagenet/data_feed.py
+0
-74
hub_module/modules/image/classification/mobilenet_v1_imagenet/label_file.txt
...image/classification/mobilenet_v1_imagenet/label_file.txt
+0
-1000
hub_module/modules/image/classification/mobilenet_v1_imagenet/mobilenet_v1.py
...mage/classification/mobilenet_v1_imagenet/mobilenet_v1.py
+0
-211
hub_module/modules/image/classification/mobilenet_v1_imagenet/module.py
...ules/image/classification/mobilenet_v1_imagenet/module.py
+236
-261
hub_module/modules/image/classification/mobilenet_v1_imagenet/processor.py
...s/image/classification/mobilenet_v1_imagenet/processor.py
+0
-4
hub_module/modules/image/classification/mobilenet_v1_imagenet_ssld/module.py
...image/classification/mobilenet_v1_imagenet_ssld/module.py
+241
-0
hub_module/modules/image/classification/mobilenet_v2_imagenet/module.py
...ules/image/classification/mobilenet_v2_imagenet/module.py
+209
-0
hub_module/modules/image/classification/shufflenet_v2_imagenet/module.py
...les/image/classification/shufflenet_v2_imagenet/module.py
+294
-0
未找到文件。
hub_module/modules/image/classification/mobilenet_v1_imagenet/__init__.py
已删除
100644 → 0
浏览文件 @
bd8e0ad0
hub_module/modules/image/classification/mobilenet_v1_imagenet/data_feed.py
已删除
100644 → 0
浏览文件 @
bd8e0ad0
# coding=utf-8
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
os
from
collections
import
OrderedDict
import
numpy
as
np
import
cv2
from
PIL
import
Image
,
ImageEnhance
from
paddle
import
fluid
DATA_DIM
=
224
img_mean
=
np
.
array
([
0.485
,
0.456
,
0.406
]).
reshape
((
3
,
1
,
1
))
img_std
=
np
.
array
([
0.229
,
0.224
,
0.225
]).
reshape
((
3
,
1
,
1
))
def
resize_short
(
img
,
target_size
):
percent
=
float
(
target_size
)
/
min
(
img
.
size
[
0
],
img
.
size
[
1
])
resized_width
=
int
(
round
(
img
.
size
[
0
]
*
percent
))
resized_height
=
int
(
round
(
img
.
size
[
1
]
*
percent
))
img
=
img
.
resize
((
resized_width
,
resized_height
),
Image
.
LANCZOS
)
return
img
def
crop_image
(
img
,
target_size
,
center
):
width
,
height
=
img
.
size
size
=
target_size
if
center
==
True
:
w_start
=
(
width
-
size
)
/
2
h_start
=
(
height
-
size
)
/
2
else
:
w_start
=
np
.
random
.
randint
(
0
,
width
-
size
+
1
)
h_start
=
np
.
random
.
randint
(
0
,
height
-
size
+
1
)
w_end
=
w_start
+
size
h_end
=
h_start
+
size
img
=
img
.
crop
((
w_start
,
h_start
,
w_end
,
h_end
))
return
img
def
process_image
(
img
):
img
=
resize_short
(
img
,
target_size
=
256
)
img
=
crop_image
(
img
,
target_size
=
DATA_DIM
,
center
=
True
)
if
img
.
mode
!=
'RGB'
:
img
=
img
.
convert
(
'RGB'
)
#img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img
=
np
.
array
(
img
).
astype
(
'float32'
).
transpose
((
2
,
0
,
1
))
/
255
img
-=
img_mean
img
/=
img_std
return
img
def
test_reader
(
paths
=
None
,
images
=
None
):
"""data generator
:param paths: path to images.
:type paths: list, each element is a str
:param images: data of images, [N, H, W, C]
:type images: numpy.ndarray
"""
img_list
=
[]
if
paths
:
for
img_path
in
paths
:
assert
os
.
path
.
isfile
(
img_path
),
"The {} isn't a valid file path."
.
format
(
img_path
)
img
=
Image
.
open
(
img_path
)
#img = cv2.imread(img_path)
img_list
.
append
(
img
)
if
images
is
not
None
:
for
img
in
images
:
img_list
.
append
(
Image
.
fromarray
(
np
.
uint8
(
img
)))
for
im
in
img_list
:
im
=
process_image
(
im
)
yield
im
hub_module/modules/image/classification/mobilenet_v1_imagenet/label_file.txt
已删除
100644 → 0
浏览文件 @
bd8e0ad0
此差异已折叠。
点击以展开。
hub_module/modules/image/classification/mobilenet_v1_imagenet/mobilenet_v1.py
已删除
100644 → 0
浏览文件 @
bd8e0ad0
# coding=utf-8
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
__all__
=
[
'MobileNet'
]
class
MobileNet
(
object
):
"""
MobileNet v1, see https://arxiv.org/abs/1704.04861
Args:
norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
norm_decay (float): weight decay for normalization layer weights
conv_group_scale (int): scaling factor for convolution groups
with_extra_blocks (bool): if extra blocks should be added
extra_block_filters (list): number of filter for each extra block
class_dim (int): number of class while classification
yolo_v3 (bool): whether to output layers which yolo_v3 needs
"""
__shared__
=
[
'norm_type'
,
'weight_prefix_name'
]
def
__init__
(
self
,
norm_type
=
'bn'
,
norm_decay
=
0.
,
conv_group_scale
=
1
,
conv_learning_rate
=
1.0
,
with_extra_blocks
=
False
,
extra_block_filters
=
[[
256
,
512
],
[
128
,
256
],
[
128
,
256
],
[
64
,
128
]],
weight_prefix_name
=
''
,
class_dim
=
1000
,
yolo_v3
=
False
):
self
.
norm_type
=
norm_type
self
.
norm_decay
=
norm_decay
self
.
conv_group_scale
=
conv_group_scale
self
.
conv_learning_rate
=
conv_learning_rate
self
.
with_extra_blocks
=
with_extra_blocks
self
.
extra_block_filters
=
extra_block_filters
self
.
prefix_name
=
weight_prefix_name
self
.
class_dim
=
class_dim
self
.
yolo_v3
=
yolo_v3
def
_conv_norm
(
self
,
input
,
filter_size
,
num_filters
,
stride
,
padding
,
num_groups
=
1
,
act
=
'relu'
,
use_cudnn
=
True
,
name
=
None
):
parameter_attr
=
ParamAttr
(
learning_rate
=
self
.
conv_learning_rate
,
initializer
=
fluid
.
initializer
.
MSRA
(),
name
=
name
+
"_weights"
)
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
groups
=
num_groups
,
act
=
None
,
use_cudnn
=
use_cudnn
,
param_attr
=
parameter_attr
,
bias_attr
=
False
)
bn_name
=
name
+
"_bn"
norm_decay
=
self
.
norm_decay
bn_param_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
norm_decay
),
name
=
bn_name
+
'_scale'
)
bn_bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
norm_decay
),
name
=
bn_name
+
'_offset'
)
return
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
act
,
param_attr
=
bn_param_attr
,
bias_attr
=
bn_bias_attr
,
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
)
def
depthwise_separable
(
self
,
input
,
num_filters1
,
num_filters2
,
num_groups
,
stride
,
scale
,
name
=
None
):
depthwise_conv
=
self
.
_conv_norm
(
input
=
input
,
filter_size
=
3
,
num_filters
=
int
(
num_filters1
*
scale
),
stride
=
stride
,
padding
=
1
,
num_groups
=
int
(
num_groups
*
scale
),
use_cudnn
=
False
,
name
=
name
+
"_dw"
)
pointwise_conv
=
self
.
_conv_norm
(
input
=
depthwise_conv
,
filter_size
=
1
,
num_filters
=
int
(
num_filters2
*
scale
),
stride
=
1
,
padding
=
0
,
name
=
name
+
"_sep"
)
return
pointwise_conv
def
_extra_block
(
self
,
input
,
num_filters1
,
num_filters2
,
num_groups
,
stride
,
name
=
None
):
pointwise_conv
=
self
.
_conv_norm
(
input
=
input
,
filter_size
=
1
,
num_filters
=
int
(
num_filters1
),
stride
=
1
,
num_groups
=
int
(
num_groups
),
padding
=
0
,
name
=
name
+
"_extra1"
)
normal_conv
=
self
.
_conv_norm
(
input
=
pointwise_conv
,
filter_size
=
3
,
num_filters
=
int
(
num_filters2
),
stride
=
2
,
num_groups
=
int
(
num_groups
),
padding
=
1
,
name
=
name
+
"_extra2"
)
return
normal_conv
def
__call__
(
self
,
input
):
scale
=
self
.
conv_group_scale
blocks
=
[]
# input 1/1
out
=
self
.
_conv_norm
(
input
,
3
,
int
(
32
*
scale
),
2
,
1
,
name
=
self
.
prefix_name
+
"conv1"
)
# 1/2
out
=
self
.
depthwise_separable
(
out
,
32
,
64
,
32
,
1
,
scale
,
name
=
self
.
prefix_name
+
"conv2_1"
)
out
=
self
.
depthwise_separable
(
out
,
64
,
128
,
64
,
2
,
scale
,
name
=
self
.
prefix_name
+
"conv2_2"
)
# 1/4
out
=
self
.
depthwise_separable
(
out
,
128
,
128
,
128
,
1
,
scale
,
name
=
self
.
prefix_name
+
"conv3_1"
)
out
=
self
.
depthwise_separable
(
out
,
128
,
256
,
128
,
2
,
scale
,
name
=
self
.
prefix_name
+
"conv3_2"
)
# 1/8
blocks
.
append
(
out
)
out
=
self
.
depthwise_separable
(
out
,
256
,
256
,
256
,
1
,
scale
,
name
=
self
.
prefix_name
+
"conv4_1"
)
out
=
self
.
depthwise_separable
(
out
,
256
,
512
,
256
,
2
,
scale
,
name
=
self
.
prefix_name
+
"conv4_2"
)
# 1/16
blocks
.
append
(
out
)
for
i
in
range
(
5
):
out
=
self
.
depthwise_separable
(
out
,
512
,
512
,
512
,
1
,
scale
,
name
=
self
.
prefix_name
+
"conv5_"
+
str
(
i
+
1
))
module11
=
out
out
=
self
.
depthwise_separable
(
out
,
512
,
1024
,
512
,
2
,
scale
,
name
=
self
.
prefix_name
+
"conv5_6"
)
# 1/32
out
=
self
.
depthwise_separable
(
out
,
1024
,
1024
,
1024
,
1
,
scale
,
name
=
self
.
prefix_name
+
"conv6"
)
module13
=
out
blocks
.
append
(
out
)
if
self
.
yolo_v3
:
return
blocks
if
not
self
.
with_extra_blocks
:
out
=
fluid
.
layers
.
pool2d
(
input
=
out
,
pool_type
=
'avg'
,
global_pooling
=
True
)
out
=
fluid
.
layers
.
fc
(
input
=
out
,
size
=
self
.
class_dim
,
param_attr
=
ParamAttr
(
initializer
=
fluid
.
initializer
.
MSRA
(),
name
=
"fc7_weights"
),
bias_attr
=
ParamAttr
(
name
=
"fc7_offset"
))
out
=
fluid
.
layers
.
softmax
(
out
)
blocks
.
append
(
out
)
return
blocks
num_filters
=
self
.
extra_block_filters
module14
=
self
.
_extra_block
(
module13
,
num_filters
[
0
][
0
],
num_filters
[
0
][
1
],
1
,
2
,
self
.
prefix_name
+
"conv7_1"
)
module15
=
self
.
_extra_block
(
module14
,
num_filters
[
1
][
0
],
num_filters
[
1
][
1
],
1
,
2
,
self
.
prefix_name
+
"conv7_2"
)
module16
=
self
.
_extra_block
(
module15
,
num_filters
[
2
][
0
],
num_filters
[
2
][
1
],
1
,
2
,
self
.
prefix_name
+
"conv7_3"
)
module17
=
self
.
_extra_block
(
module16
,
num_filters
[
3
][
0
],
num_filters
[
3
][
1
],
1
,
2
,
self
.
prefix_name
+
"conv7_4"
)
return
module11
,
module13
,
module14
,
module15
,
module16
,
module17
hub_module/modules/image/classification/mobilenet_v1_imagenet/module.py
浏览文件 @
8e135a4d
此差异已折叠。
点击以展开。
hub_module/modules/image/classification/mobilenet_v1_imagenet/processor.py
已删除
100644 → 0
浏览文件 @
bd8e0ad0
# coding=utf-8
def
load_label_info
(
file_path
):
with
open
(
file_path
,
'r'
)
as
fr
:
return
fr
.
read
().
split
(
"
\n
"
)[:
-
1
]
hub_module/modules/image/classification/mobilenet_v1_imagenet_ssld/module.py
0 → 100644
浏览文件 @
8e135a4d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
math
import
numpy
as
np
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
Conv2d
,
BatchNorm
,
Linear
,
Dropout
from
paddle.nn
import
AdaptiveAvgPool2d
,
MaxPool2d
,
AvgPool2d
from
paddle.nn.initializer
import
MSRA
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.module.cv_module
import
ImageClassifierModule
class
ConvBNLayer
(
nn
.
Layer
):
"""Basic conv bn layer."""
def
__init__
(
self
,
num_channels
:
int
,
filter_size
:
int
,
num_filters
:
int
,
stride
:
int
,
padding
:
int
,
channels
:
int
=
None
,
num_groups
:
int
=
1
,
act
:
str
=
'relu'
,
name
:
str
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
_conv
=
Conv2d
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
groups
=
num_groups
,
weight_attr
=
ParamAttr
(
initializer
=
MSRA
(),
name
=
name
+
"_weights"
),
bias_attr
=
False
)
self
.
_batch_norm
=
BatchNorm
(
num_filters
,
act
=
act
,
param_attr
=
ParamAttr
(
name
+
"_bn_scale"
),
bias_attr
=
ParamAttr
(
name
+
"_bn_offset"
),
moving_mean_name
=
name
+
"_bn_mean"
,
moving_variance_name
=
name
+
"_bn_variance"
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
y
=
self
.
_conv
(
inputs
)
y
=
self
.
_batch_norm
(
y
)
return
y
class
DepthwiseSeparable
(
nn
.
Layer
):
"""Depthwise and pointwise conv layer."""
def
__init__
(
self
,
num_channels
:
int
,
num_filters1
:
int
,
num_filters2
:
int
,
num_groups
:
int
,
stride
:
int
,
scale
:
float
,
name
:
str
=
None
):
super
(
DepthwiseSeparable
,
self
).
__init__
()
self
.
_depthwise_conv
=
ConvBNLayer
(
num_channels
=
num_channels
,
num_filters
=
int
(
num_filters1
*
scale
),
filter_size
=
3
,
stride
=
stride
,
padding
=
1
,
num_groups
=
int
(
num_groups
*
scale
),
name
=
name
+
"_dw"
)
self
.
_pointwise_conv
=
ConvBNLayer
(
num_channels
=
int
(
num_filters1
*
scale
),
filter_size
=
1
,
num_filters
=
int
(
num_filters2
*
scale
),
stride
=
1
,
padding
=
0
,
name
=
name
+
"_sep"
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
y
=
self
.
_depthwise_conv
(
inputs
)
y
=
self
.
_pointwise_conv
(
y
)
return
y
@
moduleinfo
(
name
=
"mobilenet_v1_imagenet_ssld"
,
type
=
"cv/classification"
,
author
=
"paddlepaddle"
,
author_email
=
""
,
summary
=
"mobilenet_v1_imagenet_ssld is a classification model, "
"this module is trained with Imagenet dataset."
,
version
=
"1.1.0"
,
meta
=
ImageClassifierModule
)
class
MobileNetV1
(
nn
.
Layer
):
"""MobileNetV1"""
def
__init__
(
self
,
class_dim
:
int
=
1000
,
load_checkpoint
:
str
=
None
):
super
(
MobileNetV1
,
self
).
__init__
()
self
.
block_list
=
[]
self
.
conv1
=
ConvBNLayer
(
num_channels
=
3
,
filter_size
=
3
,
channels
=
3
,
num_filters
=
int
(
32
),
stride
=
2
,
padding
=
1
,
name
=
"conv1"
)
conv2_1
=
self
.
add_sublayer
(
"conv2_1"
,
sublayer
=
DepthwiseSeparable
(
num_channels
=
int
(
32
),
num_filters1
=
32
,
num_filters2
=
64
,
num_groups
=
32
,
stride
=
1
,
scale
=
1
,
name
=
"conv2_1"
))
self
.
block_list
.
append
(
conv2_1
)
conv2_2
=
self
.
add_sublayer
(
"conv2_2"
,
sublayer
=
DepthwiseSeparable
(
num_channels
=
int
(
64
),
num_filters1
=
64
,
num_filters2
=
128
,
num_groups
=
64
,
stride
=
2
,
scale
=
1
,
name
=
"conv2_2"
))
self
.
block_list
.
append
(
conv2_2
)
conv3_1
=
self
.
add_sublayer
(
"conv3_1"
,
sublayer
=
DepthwiseSeparable
(
num_channels
=
int
(
128
),
num_filters1
=
128
,
num_filters2
=
128
,
num_groups
=
128
,
stride
=
1
,
scale
=
1
,
name
=
"conv3_1"
))
self
.
block_list
.
append
(
conv3_1
)
conv3_2
=
self
.
add_sublayer
(
"conv3_2"
,
sublayer
=
DepthwiseSeparable
(
num_channels
=
int
(
128
),
num_filters1
=
128
,
num_filters2
=
256
,
num_groups
=
128
,
stride
=
2
,
scale
=
1
,
name
=
"conv3_2"
))
self
.
block_list
.
append
(
conv3_2
)
conv4_1
=
self
.
add_sublayer
(
"conv4_1"
,
sublayer
=
DepthwiseSeparable
(
num_channels
=
int
(
256
),
num_filters1
=
256
,
num_filters2
=
256
,
num_groups
=
256
,
stride
=
1
,
scale
=
1
,
name
=
"conv4_1"
))
self
.
block_list
.
append
(
conv4_1
)
conv4_2
=
self
.
add_sublayer
(
"conv4_2"
,
sublayer
=
DepthwiseSeparable
(
num_channels
=
int
(
256
),
num_filters1
=
256
,
num_filters2
=
512
,
num_groups
=
256
,
stride
=
2
,
scale
=
1
,
name
=
"conv4_2"
))
self
.
block_list
.
append
(
conv4_2
)
for
i
in
range
(
5
):
conv5
=
self
.
add_sublayer
(
"conv5_"
+
str
(
i
+
1
),
sublayer
=
DepthwiseSeparable
(
num_channels
=
int
(
512
),
num_filters1
=
512
,
num_filters2
=
512
,
num_groups
=
512
,
stride
=
1
,
scale
=
1
,
name
=
"conv5_"
+
str
(
i
+
1
)))
self
.
block_list
.
append
(
conv5
)
conv5_6
=
self
.
add_sublayer
(
"conv5_6"
,
sublayer
=
DepthwiseSeparable
(
num_channels
=
int
(
512
),
num_filters1
=
512
,
num_filters2
=
1024
,
num_groups
=
512
,
stride
=
2
,
scale
=
1
,
name
=
"conv5_6"
))
self
.
block_list
.
append
(
conv5_6
)
conv6
=
self
.
add_sublayer
(
"conv6"
,
sublayer
=
DepthwiseSeparable
(
num_channels
=
int
(
1024
),
num_filters1
=
1024
,
num_filters2
=
1024
,
num_groups
=
1024
,
stride
=
1
,
scale
=
1
,
name
=
"conv6"
))
self
.
block_list
.
append
(
conv6
)
self
.
pool2d_avg
=
AdaptiveAvgPool2d
(
1
)
self
.
out
=
Linear
(
int
(
1024
),
class_dim
,
weight_attr
=
ParamAttr
(
initializer
=
MSRA
(),
name
=
"fc7_weights"
),
bias_attr
=
ParamAttr
(
name
=
"fc7_offset"
))
if
load_checkpoint
is
not
None
:
model_dict
=
paddle
.
load
(
load_checkpoint
)[
0
]
self
.
set_dict
(
model_dict
)
print
(
"load custom checkpoint success"
)
else
:
checkpoint
=
os
.
path
.
join
(
self
.
directory
,
'mobilenet_v1_ssld_imagenet.pdparams'
)
if
not
os
.
path
.
exists
(
checkpoint
):
os
.
system
(
'wget https://paddlehub.bj.bcebos.com/dygraph/image_classification/mobilenet_v1_ssld_imagenet.pdparams -O '
+
checkpoint
)
model_dict
=
paddle
.
load
(
checkpoint
)[
0
]
self
.
set_dict
(
model_dict
)
print
(
"load pretrained checkpoint success"
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
y
=
self
.
conv1
(
inputs
)
for
block
in
self
.
block_list
:
y
=
block
(
y
)
y
=
self
.
pool2d_avg
(
y
)
y
=
paddle
.
reshape
(
y
,
shape
=
[
-
1
,
1024
])
y
=
self
.
out
(
y
)
return
y
hub_module/modules/image/classification/mobilenet_v2_imagenet/module.py
0 → 100644
浏览文件 @
8e135a4d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
Conv2d
,
BatchNorm
,
Linear
,
Dropout
from
paddle.nn
import
AdaptiveAvgPool2d
,
MaxPool2d
,
AvgPool2d
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.module.cv_module
import
ImageClassifierModule
class
ConvBNLayer
(
nn
.
Layer
):
"""Basic conv bn layer."""
def
__init__
(
self
,
num_channels
:
int
,
filter_size
:
int
,
num_filters
:
int
,
stride
:
int
,
padding
:
int
,
num_groups
:
int
=
1
,
name
:
str
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
_conv
=
Conv2d
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
groups
=
num_groups
,
weight_attr
=
ParamAttr
(
name
=
name
+
"_weights"
),
bias_attr
=
False
)
self
.
_batch_norm
=
BatchNorm
(
num_filters
,
param_attr
=
ParamAttr
(
name
=
name
+
"_bn_scale"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"_bn_offset"
),
moving_mean_name
=
name
+
"_bn_mean"
,
moving_variance_name
=
name
+
"_bn_variance"
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
,
if_act
:
bool
=
True
):
y
=
self
.
_conv
(
inputs
)
y
=
self
.
_batch_norm
(
y
)
if
if_act
:
y
=
F
.
relu6
(
y
)
return
y
class
InvertedResidualUnit
(
nn
.
Layer
):
"""Inverted Residual unit."""
def
__init__
(
self
,
num_channels
:
int
,
num_in_filter
:
int
,
num_filters
:
int
,
stride
:
int
,
filter_size
:
int
,
padding
:
int
,
expansion_factor
:
int
,
name
:
str
):
super
(
InvertedResidualUnit
,
self
).
__init__
()
num_expfilter
=
int
(
round
(
num_in_filter
*
expansion_factor
))
self
.
_expand_conv
=
ConvBNLayer
(
num_channels
=
num_channels
,
num_filters
=
num_expfilter
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
num_groups
=
1
,
name
=
name
+
"_expand"
)
self
.
_bottleneck_conv
=
ConvBNLayer
(
num_channels
=
num_expfilter
,
num_filters
=
num_expfilter
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
num_groups
=
num_expfilter
,
name
=
name
+
"_dwise"
)
self
.
_linear_conv
=
ConvBNLayer
(
num_channels
=
num_expfilter
,
num_filters
=
num_filters
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
num_groups
=
1
,
name
=
name
+
"_linear"
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
,
ifshortcut
:
bool
):
y
=
self
.
_expand_conv
(
inputs
,
if_act
=
True
)
y
=
self
.
_bottleneck_conv
(
y
,
if_act
=
True
)
y
=
self
.
_linear_conv
(
y
,
if_act
=
False
)
if
ifshortcut
:
y
=
paddle
.
elementwise_add
(
inputs
,
y
)
return
y
class
InversiBlocks
(
nn
.
Layer
):
"""Inverted residual block composed by inverted residual unit."""
def
__init__
(
self
,
in_c
:
int
,
t
:
int
,
c
:
int
,
n
:
int
,
s
:
int
,
name
:
str
):
super
(
InversiBlocks
,
self
).
__init__
()
self
.
_first_block
=
InvertedResidualUnit
(
num_channels
=
in_c
,
num_in_filter
=
in_c
,
num_filters
=
c
,
stride
=
s
,
filter_size
=
3
,
padding
=
1
,
expansion_factor
=
t
,
name
=
name
+
"_1"
)
self
.
_block_list
=
[]
for
i
in
range
(
1
,
n
):
block
=
self
.
add_sublayer
(
name
+
"_"
+
str
(
i
+
1
),
sublayer
=
InvertedResidualUnit
(
num_channels
=
c
,
num_in_filter
=
c
,
num_filters
=
c
,
stride
=
1
,
filter_size
=
3
,
padding
=
1
,
expansion_factor
=
t
,
name
=
name
+
"_"
+
str
(
i
+
1
)))
self
.
_block_list
.
append
(
block
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
y
=
self
.
_first_block
(
inputs
,
ifshortcut
=
False
)
for
block
in
self
.
_block_list
:
y
=
block
(
y
,
ifshortcut
=
True
)
return
y
@
moduleinfo
(
name
=
"mobilenet_v2_imagenet"
,
type
=
"cv/classification"
,
author
=
"paddlepaddle"
,
author_email
=
""
,
summary
=
"mobilenet_v2_imagenet is a classification model, "
"this module is trained with Imagenet dataset."
,
version
=
"1.1.0"
,
meta
=
ImageClassifierModule
)
class
MobileNet
(
nn
.
Layer
):
"""MobileNetV2"""
def
__init__
(
self
,
class_dim
:
int
=
1000
,
load_checkpoint
:
str
=
None
):
super
(
MobileNet
,
self
).
__init__
()
self
.
class_dim
=
class_dim
bottleneck_params_list
=
[(
1
,
16
,
1
,
1
),
(
6
,
24
,
2
,
2
),
(
6
,
32
,
3
,
2
),
(
6
,
64
,
4
,
2
),
(
6
,
96
,
3
,
1
),
(
6
,
160
,
3
,
2
),
(
6
,
320
,
1
,
1
)]
self
.
conv1
=
ConvBNLayer
(
num_channels
=
3
,
num_filters
=
int
(
32
),
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
name
=
"conv1_1"
)
self
.
block_list
=
[]
i
=
1
in_c
=
int
(
32
)
for
layer_setting
in
bottleneck_params_list
:
t
,
c
,
n
,
s
=
layer_setting
i
+=
1
block
=
self
.
add_sublayer
(
"conv"
+
str
(
i
),
sublayer
=
InversiBlocks
(
in_c
=
in_c
,
t
=
t
,
c
=
int
(
c
),
n
=
n
,
s
=
s
,
name
=
"conv"
+
str
(
i
)))
self
.
block_list
.
append
(
block
)
in_c
=
int
(
c
)
self
.
out_c
=
1280
self
.
conv9
=
ConvBNLayer
(
num_channels
=
in_c
,
num_filters
=
self
.
out_c
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
name
=
"conv9"
)
self
.
pool2d_avg
=
AdaptiveAvgPool2d
(
1
)
self
.
out
=
Linear
(
self
.
out_c
,
class_dim
,
weight_attr
=
ParamAttr
(
name
=
"fc10_weights"
),
bias_attr
=
ParamAttr
(
name
=
"fc10_offset"
))
if
load_checkpoint
is
not
None
:
model_dict
=
paddle
.
load
(
load_checkpoint
)[
0
]
self
.
set_dict
(
model_dict
)
print
(
"load custom checkpoint success"
)
else
:
checkpoint
=
os
.
path
.
join
(
self
.
directory
,
'mobilenet_v2_imagenet.pdparams'
)
if
not
os
.
path
.
exists
(
checkpoint
):
os
.
system
(
'wget https://paddlehub.bj.bcebos.com/dygraph/image_classification/mobilenet_v2_imagenet.pdparams -O '
+
checkpoint
)
model_dict
=
paddle
.
load
(
checkpoint
)[
0
]
self
.
set_dict
(
model_dict
)
print
(
"load pretrained checkpoint success"
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
y
=
self
.
conv1
(
inputs
,
if_act
=
True
)
for
block
in
self
.
block_list
:
y
=
block
(
y
)
y
=
self
.
conv9
(
y
,
if_act
=
True
)
y
=
self
.
pool2d_avg
(
y
)
y
=
paddle
.
reshape
(
y
,
shape
=
[
-
1
,
self
.
out_c
])
y
=
self
.
out
(
y
)
return
y
hub_module/modules/image/classification/shufflenet_v2_imagenet/module.py
0 → 100644
浏览文件 @
8e135a4d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
Conv2d
,
BatchNorm
,
Linear
,
Dropout
from
paddle.nn
import
AdaptiveAvgPool2d
,
MaxPool2d
,
AvgPool2d
from
paddle.nn.initializer
import
MSRA
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.module.cv_module
import
ImageClassifierModule
def
channel_shuffle
(
x
:
paddle
.
Tensor
,
groups
:
int
):
"""Shuffle input channels."""
batchsize
,
num_channels
,
height
,
width
=
x
.
shape
[
0
],
x
.
shape
[
1
],
x
.
shape
[
2
],
x
.
shape
[
3
]
channels_per_group
=
num_channels
//
groups
# reshape
x
=
paddle
.
reshape
(
x
=
x
,
shape
=
[
batchsize
,
groups
,
channels_per_group
,
height
,
width
])
x
=
paddle
.
transpose
(
x
=
x
,
perm
=
[
0
,
2
,
1
,
3
,
4
])
# flatten
x
=
paddle
.
reshape
(
x
=
x
,
shape
=
[
batchsize
,
num_channels
,
height
,
width
])
return
x
class
ConvBNLayer
(
nn
.
Layer
):
"""Basic conv bn layer."""
def
__init__
(
self
,
num_channels
:
int
,
filter_size
:
int
,
num_filters
:
int
,
stride
:
int
,
padding
:
int
,
channels
:
int
=
None
,
num_groups
:
int
=
1
,
if_act
:
bool
=
True
,
act
:
str
=
'relu'
,
name
:
str
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
_if_act
=
if_act
assert
act
in
[
'relu'
,
'swish'
],
\
"supported act are {} but your act is {}"
.
format
(
[
'relu'
,
'swish'
],
act
)
self
.
_act
=
act
self
.
_conv
=
Conv2d
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
groups
=
num_groups
,
weight_attr
=
ParamAttr
(
initializer
=
MSRA
(),
name
=
name
+
"_weights"
),
bias_attr
=
False
)
self
.
_batch_norm
=
BatchNorm
(
num_filters
,
param_attr
=
ParamAttr
(
name
=
name
+
"_bn_scale"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"_bn_offset"
),
moving_mean_name
=
name
+
"_bn_mean"
,
moving_variance_name
=
name
+
"_bn_variance"
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
,
if_act
:
bool
=
True
):
y
=
self
.
_conv
(
inputs
)
y
=
self
.
_batch_norm
(
y
)
if
self
.
_if_act
:
y
=
F
.
relu
(
y
)
if
self
.
_act
==
'relu'
else
F
.
swish
(
y
)
return
y
class
InvertedResidualUnit
(
nn
.
Layer
):
"""Inverted Residual unit."""
def
__init__
(
self
,
num_channels
:
int
,
num_filters
:
int
,
stride
:
int
,
benchmodel
:
int
,
act
:
str
=
'relu'
,
name
:
str
=
None
):
super
(
InvertedResidualUnit
,
self
).
__init__
()
assert
stride
in
[
1
,
2
],
\
"supported stride are {} but your stride is {}"
.
format
([
1
,
2
],
stride
)
self
.
benchmodel
=
benchmodel
oup_inc
=
num_filters
//
2
inp
=
num_channels
if
benchmodel
==
1
:
self
.
_conv_pw
=
ConvBNLayer
(
num_channels
=
num_channels
//
2
,
num_filters
=
oup_inc
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
num_groups
=
1
,
if_act
=
True
,
act
=
act
,
name
=
'stage_'
+
name
+
'_conv1'
)
self
.
_conv_dw
=
ConvBNLayer
(
num_channels
=
oup_inc
,
num_filters
=
oup_inc
,
filter_size
=
3
,
stride
=
stride
,
padding
=
1
,
num_groups
=
oup_inc
,
if_act
=
False
,
act
=
act
,
name
=
'stage_'
+
name
+
'_conv2'
)
self
.
_conv_linear
=
ConvBNLayer
(
num_channels
=
oup_inc
,
num_filters
=
oup_inc
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
num_groups
=
1
,
if_act
=
True
,
act
=
act
,
name
=
'stage_'
+
name
+
'_conv3'
)
else
:
# branch1
self
.
_conv_dw_1
=
ConvBNLayer
(
num_channels
=
num_channels
,
num_filters
=
inp
,
filter_size
=
3
,
stride
=
stride
,
padding
=
1
,
num_groups
=
inp
,
if_act
=
False
,
act
=
act
,
name
=
'stage_'
+
name
+
'_conv4'
)
self
.
_conv_linear_1
=
ConvBNLayer
(
num_channels
=
inp
,
num_filters
=
oup_inc
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
num_groups
=
1
,
if_act
=
True
,
act
=
act
,
name
=
'stage_'
+
name
+
'_conv5'
)
# branch2
self
.
_conv_pw_2
=
ConvBNLayer
(
num_channels
=
num_channels
,
num_filters
=
oup_inc
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
num_groups
=
1
,
if_act
=
True
,
act
=
act
,
name
=
'stage_'
+
name
+
'_conv1'
)
self
.
_conv_dw_2
=
ConvBNLayer
(
num_channels
=
oup_inc
,
num_filters
=
oup_inc
,
filter_size
=
3
,
stride
=
stride
,
padding
=
1
,
num_groups
=
oup_inc
,
if_act
=
False
,
act
=
act
,
name
=
'stage_'
+
name
+
'_conv2'
)
self
.
_conv_linear_2
=
ConvBNLayer
(
num_channels
=
oup_inc
,
num_filters
=
oup_inc
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
num_groups
=
1
,
if_act
=
True
,
act
=
act
,
name
=
'stage_'
+
name
+
'_conv3'
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
if
self
.
benchmodel
==
1
:
x1
,
x2
=
paddle
.
split
(
inputs
,
num_or_sections
=
[
inputs
.
shape
[
1
]
//
2
,
inputs
.
shape
[
1
]
//
2
],
axis
=
1
)
x2
=
self
.
_conv_pw
(
x2
)
x2
=
self
.
_conv_dw
(
x2
)
x2
=
self
.
_conv_linear
(
x2
)
out
=
paddle
.
concat
([
x1
,
x2
],
axis
=
1
)
else
:
x1
=
self
.
_conv_dw_1
(
inputs
)
x1
=
self
.
_conv_linear_1
(
x1
)
x2
=
self
.
_conv_pw_2
(
inputs
)
x2
=
self
.
_conv_dw_2
(
x2
)
x2
=
self
.
_conv_linear_2
(
x2
)
out
=
paddle
.
concat
([
x1
,
x2
],
axis
=
1
)
return
channel_shuffle
(
out
,
2
)
@
moduleinfo
(
name
=
"shufflenet_v2_imagenet"
,
type
=
"cv/classification"
,
author
=
"paddlepaddle"
,
author_email
=
""
,
summary
=
"shufflenet_v2_imagenet is a classification model, "
"this module is trained with Imagenet dataset."
,
version
=
"1.1.0"
,
meta
=
ImageClassifierModule
)
class
ShuffleNet
(
nn
.
Layer
):
"""ShuffleNet model."""
def
__init__
(
self
,
class_dim
:
int
=
1000
,
load_checkpoint
:
str
=
None
):
super
(
ShuffleNet
,
self
).
__init__
()
self
.
scale
=
1
self
.
class_dim
=
class_dim
stage_repeats
=
[
4
,
8
,
4
]
stage_out_channels
=
[
-
1
,
24
,
116
,
232
,
464
,
1024
]
# 1. conv1
self
.
_conv1
=
ConvBNLayer
(
num_channels
=
3
,
num_filters
=
stage_out_channels
[
1
],
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
if_act
=
True
,
act
=
'relu'
,
name
=
'stage1_conv'
)
self
.
_max_pool
=
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
# 2. bottleneck sequences
self
.
_block_list
=
[]
i
=
1
in_c
=
int
(
32
)
for
idxstage
in
range
(
len
(
stage_repeats
)):
numrepeat
=
stage_repeats
[
idxstage
]
output_channel
=
stage_out_channels
[
idxstage
+
2
]
for
i
in
range
(
numrepeat
):
if
i
==
0
:
block
=
self
.
add_sublayer
(
str
(
idxstage
+
2
)
+
'_'
+
str
(
i
+
1
),
InvertedResidualUnit
(
num_channels
=
stage_out_channels
[
idxstage
+
1
],
num_filters
=
output_channel
,
stride
=
2
,
benchmodel
=
2
,
act
=
'relu'
,
name
=
str
(
idxstage
+
2
)
+
'_'
+
str
(
i
+
1
)))
self
.
_block_list
.
append
(
block
)
else
:
block
=
self
.
add_sublayer
(
str
(
idxstage
+
2
)
+
'_'
+
str
(
i
+
1
),
InvertedResidualUnit
(
num_channels
=
output_channel
,
num_filters
=
output_channel
,
stride
=
1
,
benchmodel
=
1
,
act
=
'relu'
,
name
=
str
(
idxstage
+
2
)
+
'_'
+
str
(
i
+
1
)))
self
.
_block_list
.
append
(
block
)
# 3. last_conv
self
.
_last_conv
=
ConvBNLayer
(
num_channels
=
stage_out_channels
[
-
2
],
num_filters
=
stage_out_channels
[
-
1
],
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
if_act
=
True
,
act
=
'relu'
,
name
=
'conv5'
)
# 4. pool
self
.
_pool2d_avg
=
AdaptiveAvgPool2d
(
1
)
self
.
_out_c
=
stage_out_channels
[
-
1
]
# 5. fc
self
.
_fc
=
Linear
(
stage_out_channels
[
-
1
],
class_dim
,
weight_attr
=
ParamAttr
(
name
=
'fc6_weights'
),
bias_attr
=
ParamAttr
(
name
=
'fc6_offset'
))
if
load_checkpoint
is
not
None
:
model_dict
=
paddle
.
load
(
load_checkpoint
)[
0
]
self
.
set_dict
(
model_dict
)
print
(
"load custom checkpoint success"
)
else
:
checkpoint
=
os
.
path
.
join
(
self
.
directory
,
'shufflenet_v2_imagenet.pdparams'
)
if
not
os
.
path
.
exists
(
checkpoint
):
os
.
system
(
'wget https://paddlehub.bj.bcebos.com/dygraph/image_classification/shufflenet_v2_imagenet.pdparams -O '
+
checkpoint
)
model_dict
=
paddle
.
load
(
checkpoint
)[
0
]
self
.
set_dict
(
model_dict
)
print
(
"load pretrained checkpoint success"
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
y
=
self
.
_conv1
(
inputs
)
y
=
self
.
_max_pool
(
y
)
for
inv
in
self
.
_block_list
:
y
=
inv
(
y
)
y
=
self
.
_last_conv
(
y
)
y
=
self
.
_pool2d_avg
(
y
)
y
=
paddle
.
reshape
(
y
,
shape
=
[
-
1
,
self
.
_out_c
])
y
=
self
.
_fc
(
y
)
return
y
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录