Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
68af310b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
68af310b
编写于
3月 09, 2022
作者:
N
Nyakku Shigure
提交者:
GitHub
3月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add MobileNetV3 (#38653)
* add mobilenetv3
上级
767647ce
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
554 addition
and
15 deletion
+554
-15
python/paddle/tests/test_pretrained_model.py
python/paddle/tests/test_pretrained_model.py
+2
-0
python/paddle/tests/test_vision_models.py
python/paddle/tests/test_vision_models.py
+6
-0
python/paddle/vision/__init__.py
python/paddle/vision/__init__.py
+4
-0
python/paddle/vision/models/__init__.py
python/paddle/vision/models/__init__.py
+8
-0
python/paddle/vision/models/mobilenetv2.py
python/paddle/vision/models/mobilenetv2.py
+2
-14
python/paddle/vision/models/mobilenetv3.py
python/paddle/vision/models/mobilenetv3.py
+445
-0
python/paddle/vision/models/utils.py
python/paddle/vision/models/utils.py
+32
-0
python/paddle/vision/ops.py
python/paddle/vision/ops.py
+55
-1
未找到文件。
python/paddle/tests/test_pretrained_model.py
浏览文件 @
68af310b
...
...
@@ -61,6 +61,8 @@ class TestPretrainedModel(unittest.TestCase):
arches
=
[
'mobilenet_v1'
,
'mobilenet_v2'
,
'mobilenet_v3_small'
,
'mobilenet_v3_large'
,
'squeezenet1_0'
,
'shufflenet_v2_x0_25'
,
]
...
...
python/paddle/tests/test_vision_models.py
浏览文件 @
68af310b
...
...
@@ -40,6 +40,12 @@ class TestVisonModels(unittest.TestCase):
def
test_mobilenetv1
(
self
):
self
.
models_infer
(
'mobilenet_v1'
)
def
test_mobilenetv3_small
(
self
):
self
.
models_infer
(
'mobilenet_v3_small'
)
def
test_mobilenetv3_large
(
self
):
self
.
models_infer
(
'mobilenet_v3_large'
)
def
test_vgg11
(
self
):
self
.
models_infer
(
'vgg11'
)
...
...
python/paddle/vision/__init__.py
浏览文件 @
68af310b
...
...
@@ -40,6 +40,10 @@ from .models import MobileNetV1 # noqa: F401
from
.models
import
mobilenet_v1
# noqa: F401
from
.models
import
MobileNetV2
# noqa: F401
from
.models
import
mobilenet_v2
# noqa: F401
from
.models
import
MobileNetV3Small
# noqa: F401
from
.models
import
MobileNetV3Large
# noqa: F401
from
.models
import
mobilenet_v3_small
# noqa: F401
from
.models
import
mobilenet_v3_large
# noqa: F401
from
.models
import
SqueezeNet
# noqa: F401
from
.models
import
squeezenet1_0
# noqa: F401
from
.models
import
squeezenet1_1
# noqa: F401
...
...
python/paddle/vision/models/__init__.py
浏览文件 @
68af310b
...
...
@@ -24,6 +24,10 @@ from .mobilenetv1 import MobileNetV1 # noqa: F401
from
.mobilenetv1
import
mobilenet_v1
# noqa: F401
from
.mobilenetv2
import
MobileNetV2
# noqa: F401
from
.mobilenetv2
import
mobilenet_v2
# noqa: F401
from
.mobilenetv3
import
MobileNetV3Small
# noqa: F401
from
.mobilenetv3
import
MobileNetV3Large
# noqa: F401
from
.mobilenetv3
import
mobilenet_v3_small
# noqa: F401
from
.mobilenetv3
import
mobilenet_v3_large
# noqa: F401
from
.vgg
import
VGG
# noqa: F401
from
.vgg
import
vgg11
# noqa: F401
from
.vgg
import
vgg13
# noqa: F401
...
...
@@ -79,6 +83,10 @@ __all__ = [ #noqa
'mobilenet_v1'
,
'MobileNetV2'
,
'mobilenet_v2'
,
'MobileNetV3Small'
,
'MobileNetV3Large'
,
'mobilenet_v3_small'
,
'mobilenet_v3_large'
,
'LeNet'
,
'DenseNet'
,
'densenet121'
,
...
...
python/paddle/vision/models/mobilenetv2.py
浏览文件 @
68af310b
...
...
@@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.utils.download
import
get_weights_path_from_url
from
.utils
import
_make_divisible
__all__
=
[]
model_urls
=
{
...
...
@@ -29,16 +27,6 @@ model_urls = {
}
def
_make_divisible
(
v
,
divisor
,
min_value
=
None
):
if
min_value
is
None
:
min_value
=
divisor
new_v
=
max
(
min_value
,
int
(
v
+
divisor
/
2
)
//
divisor
*
divisor
)
if
new_v
<
0.9
*
v
:
new_v
+=
divisor
return
new_v
class
ConvBNReLU
(
nn
.
Sequential
):
def
__init__
(
self
,
in_planes
,
...
...
python/paddle/vision/models/mobilenetv3.py
0 → 100644
浏览文件 @
68af310b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
paddle.nn
as
nn
from
paddle.utils.download
import
get_weights_path_from_url
from
functools
import
partial
from
.utils
import
_make_divisible
from
..ops
import
ConvNormActivation
__all__
=
[]
model_urls
=
{
"mobilenet_v3_small_x1.0"
:
(
"https://paddle-hapi.bj.bcebos.com/models/mobilenet_v3_small_x1.0.pdparams"
,
"34fe0e7c1f8b00b2b056ad6788d0590c"
),
"mobilenet_v3_large_x1.0"
:
(
"https://paddle-hapi.bj.bcebos.com/models/mobilenet_v3_large_x1.0.pdparams"
,
"118db5792b4e183b925d8e8e334db3df"
),
}
class
SqueezeExcitation
(
nn
.
Layer
):
"""
This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in in eq. 3.
This code is based on the torchvision code with modifications.
You can also see at https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py#L127
Args:
input_channels (int): Number of channels in the input image
squeeze_channels (int): Number of squeeze channels
activation (Callable[..., paddle.nn.Layer], optional): ``delta`` activation. Default: ``paddle.nn.ReLU``
scale_activation (Callable[..., paddle.nn.Layer]): ``sigma`` activation. Default: ``paddle.nn.Sigmoid``
"""
def
__init__
(
self
,
input_channels
,
squeeze_channels
,
activation
=
nn
.
ReLU
,
scale_activation
=
nn
.
Sigmoid
):
super
().
__init__
()
self
.
avgpool
=
nn
.
AdaptiveAvgPool2D
(
1
)
self
.
fc1
=
nn
.
Conv2D
(
input_channels
,
squeeze_channels
,
1
)
self
.
fc2
=
nn
.
Conv2D
(
squeeze_channels
,
input_channels
,
1
)
self
.
activation
=
activation
()
self
.
scale_activation
=
scale_activation
()
def
_scale
(
self
,
input
):
scale
=
self
.
avgpool
(
input
)
scale
=
self
.
fc1
(
scale
)
scale
=
self
.
activation
(
scale
)
scale
=
self
.
fc2
(
scale
)
return
self
.
scale_activation
(
scale
)
def
forward
(
self
,
input
):
scale
=
self
.
_scale
(
input
)
return
scale
*
input
class
InvertedResidualConfig
:
def
__init__
(
self
,
in_channels
,
kernel
,
expanded_channels
,
out_channels
,
use_se
,
activation
,
stride
,
scale
=
1.0
):
self
.
in_channels
=
self
.
adjust_channels
(
in_channels
,
scale
=
scale
)
self
.
kernel
=
kernel
self
.
expanded_channels
=
self
.
adjust_channels
(
expanded_channels
,
scale
=
scale
)
self
.
out_channels
=
self
.
adjust_channels
(
out_channels
,
scale
=
scale
)
self
.
use_se
=
use_se
if
activation
is
None
:
self
.
activation_layer
=
None
elif
activation
==
"relu"
:
self
.
activation_layer
=
nn
.
ReLU
elif
activation
==
"hardswish"
:
self
.
activation_layer
=
nn
.
Hardswish
else
:
raise
RuntimeError
(
"The activation function is not supported: {}"
.
format
(
activation
))
self
.
stride
=
stride
@
staticmethod
def
adjust_channels
(
channels
,
scale
=
1.0
):
return
_make_divisible
(
channels
*
scale
,
8
)
class
InvertedResidual
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
expanded_channels
,
out_channels
,
filter_size
,
stride
,
use_se
,
activation_layer
,
norm_layer
):
super
().
__init__
()
self
.
use_res_connect
=
stride
==
1
and
in_channels
==
out_channels
self
.
use_se
=
use_se
self
.
expand
=
in_channels
!=
expanded_channels
if
self
.
expand
:
self
.
expand_conv
=
ConvNormActivation
(
in_channels
=
in_channels
,
out_channels
=
expanded_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
norm_layer
=
norm_layer
,
activation_layer
=
activation_layer
)
self
.
bottleneck_conv
=
ConvNormActivation
(
in_channels
=
expanded_channels
,
out_channels
=
expanded_channels
,
kernel_size
=
filter_size
,
stride
=
stride
,
padding
=
int
((
filter_size
-
1
)
//
2
),
groups
=
expanded_channels
,
norm_layer
=
norm_layer
,
activation_layer
=
activation_layer
)
if
self
.
use_se
:
self
.
mid_se
=
SqueezeExcitation
(
expanded_channels
,
_make_divisible
(
expanded_channels
//
4
),
scale_activation
=
nn
.
Hardsigmoid
)
self
.
linear_conv
=
ConvNormActivation
(
in_channels
=
expanded_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
norm_layer
=
norm_layer
,
activation_layer
=
None
)
def
forward
(
self
,
x
):
identity
=
x
if
self
.
expand
:
x
=
self
.
expand_conv
(
x
)
x
=
self
.
bottleneck_conv
(
x
)
if
self
.
use_se
:
x
=
self
.
mid_se
(
x
)
x
=
self
.
linear_conv
(
x
)
if
self
.
use_res_connect
:
x
=
paddle
.
add
(
identity
,
x
)
return
x
class
MobileNetV3
(
nn
.
Layer
):
"""MobileNetV3 model from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
config (list[InvertedResidualConfig]): MobileNetV3 depthwise blocks config.
last_channel (int): The number of channels on the penultimate layer.
scale (float, optional): Scale of channels in each layer. Default: 1.0.
num_classes (int, optional): Output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool, optional): Use pool before the last fc layer or not. Default: True.
"""
def
__init__
(
self
,
config
,
last_channel
,
scale
=
1.0
,
num_classes
=
1000
,
with_pool
=
True
):
super
().
__init__
()
self
.
config
=
config
self
.
scale
=
scale
self
.
last_channel
=
last_channel
self
.
num_classes
=
num_classes
self
.
with_pool
=
with_pool
self
.
firstconv_in_channels
=
config
[
0
].
in_channels
self
.
lastconv_in_channels
=
config
[
-
1
].
in_channels
self
.
lastconv_out_channels
=
self
.
lastconv_in_channels
*
6
norm_layer
=
partial
(
nn
.
BatchNorm2D
,
epsilon
=
0.001
,
momentum
=
0.99
)
self
.
conv
=
ConvNormActivation
(
in_channels
=
3
,
out_channels
=
self
.
firstconv_in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
groups
=
1
,
activation_layer
=
nn
.
Hardswish
,
norm_layer
=
norm_layer
)
self
.
blocks
=
nn
.
Sequential
(
*
[
InvertedResidual
(
in_channels
=
cfg
.
in_channels
,
expanded_channels
=
cfg
.
expanded_channels
,
out_channels
=
cfg
.
out_channels
,
filter_size
=
cfg
.
kernel
,
stride
=
cfg
.
stride
,
use_se
=
cfg
.
use_se
,
activation_layer
=
cfg
.
activation_layer
,
norm_layer
=
norm_layer
)
for
cfg
in
self
.
config
])
self
.
lastconv
=
ConvNormActivation
(
in_channels
=
self
.
lastconv_in_channels
,
out_channels
=
self
.
lastconv_out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
groups
=
1
,
norm_layer
=
norm_layer
,
activation_layer
=
nn
.
Hardswish
)
if
with_pool
:
self
.
avgpool
=
nn
.
AdaptiveAvgPool2D
(
1
)
if
num_classes
>
0
:
self
.
classifier
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
lastconv_out_channels
,
self
.
last_channel
),
nn
.
Hardswish
(),
nn
.
Dropout
(
p
=
0.2
),
nn
.
Linear
(
self
.
last_channel
,
num_classes
))
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
blocks
(
x
)
x
=
self
.
lastconv
(
x
)
if
self
.
with_pool
:
x
=
self
.
avgpool
(
x
)
if
self
.
num_classes
>
0
:
x
=
paddle
.
flatten
(
x
,
1
)
x
=
self
.
classifier
(
x
)
return
x
class
MobileNetV3Small
(
MobileNetV3
):
"""MobileNetV3 Small architecture model from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
scale (float, optional): Scale of channels in each layer. Default: 1.0.
num_classes (int, optional): Output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool, optional): Use pool before the last fc layer or not. Default: True.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import MobileNetV3Small
# build model
model = MobileNetV3Small(scale=1.0)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
"""
def
__init__
(
self
,
scale
=
1.0
,
num_classes
=
1000
,
with_pool
=
True
):
config
=
[
InvertedResidualConfig
(
16
,
3
,
16
,
16
,
True
,
"relu"
,
2
,
scale
),
InvertedResidualConfig
(
16
,
3
,
72
,
24
,
False
,
"relu"
,
2
,
scale
),
InvertedResidualConfig
(
24
,
3
,
88
,
24
,
False
,
"relu"
,
1
,
scale
),
InvertedResidualConfig
(
24
,
5
,
96
,
40
,
True
,
"hardswish"
,
2
,
scale
),
InvertedResidualConfig
(
40
,
5
,
240
,
40
,
True
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
40
,
5
,
240
,
40
,
True
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
40
,
5
,
120
,
48
,
True
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
48
,
5
,
144
,
48
,
True
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
48
,
5
,
288
,
96
,
True
,
"hardswish"
,
2
,
scale
),
InvertedResidualConfig
(
96
,
5
,
576
,
96
,
True
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
96
,
5
,
576
,
96
,
True
,
"hardswish"
,
1
,
scale
),
]
last_channel
=
_make_divisible
(
1024
*
scale
,
8
)
super
().
__init__
(
config
,
last_channel
=
last_channel
,
scale
=
scale
,
with_pool
=
with_pool
,
num_classes
=
num_classes
)
class
MobileNetV3Large
(
MobileNetV3
):
"""MobileNetV3 Large architecture model from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
scale (float, optional): Scale of channels in each layer. Default: 1.0.
num_classes (int, optional): Output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool, optional): Use pool before the last fc layer or not. Default: True.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import MobileNetV3Large
# build model
model = MobileNetV3Large(scale=1.0)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
"""
def
__init__
(
self
,
scale
=
1.0
,
num_classes
=
1000
,
with_pool
=
True
):
config
=
[
InvertedResidualConfig
(
16
,
3
,
16
,
16
,
False
,
"relu"
,
1
,
scale
),
InvertedResidualConfig
(
16
,
3
,
64
,
24
,
False
,
"relu"
,
2
,
scale
),
InvertedResidualConfig
(
24
,
3
,
72
,
24
,
False
,
"relu"
,
1
,
scale
),
InvertedResidualConfig
(
24
,
5
,
72
,
40
,
True
,
"relu"
,
2
,
scale
),
InvertedResidualConfig
(
40
,
5
,
120
,
40
,
True
,
"relu"
,
1
,
scale
),
InvertedResidualConfig
(
40
,
5
,
120
,
40
,
True
,
"relu"
,
1
,
scale
),
InvertedResidualConfig
(
40
,
3
,
240
,
80
,
False
,
"hardswish"
,
2
,
scale
),
InvertedResidualConfig
(
80
,
3
,
200
,
80
,
False
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
80
,
3
,
184
,
80
,
False
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
80
,
3
,
184
,
80
,
False
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
80
,
3
,
480
,
112
,
True
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
112
,
3
,
672
,
112
,
True
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
112
,
5
,
672
,
160
,
True
,
"hardswish"
,
2
,
scale
),
InvertedResidualConfig
(
160
,
5
,
960
,
160
,
True
,
"hardswish"
,
1
,
scale
),
InvertedResidualConfig
(
160
,
5
,
960
,
160
,
True
,
"hardswish"
,
1
,
scale
),
]
last_channel
=
_make_divisible
(
1280
*
scale
,
8
)
super
().
__init__
(
config
,
last_channel
=
last_channel
,
scale
=
scale
,
with_pool
=
with_pool
,
num_classes
=
num_classes
)
def
_mobilenet_v3
(
arch
,
pretrained
=
False
,
scale
=
1.0
,
**
kwargs
):
if
arch
==
"mobilenet_v3_large"
:
model
=
MobileNetV3Large
(
scale
=
scale
,
**
kwargs
)
else
:
model
=
MobileNetV3Small
(
scale
=
scale
,
**
kwargs
)
if
pretrained
:
arch
=
"{}_x{}"
.
format
(
arch
,
scale
)
assert
(
arch
in
model_urls
),
"{} model do not have a pretrained model now, you should set pretrained=False"
.
format
(
arch
)
weight_path
=
get_weights_path_from_url
(
model_urls
[
arch
][
0
],
model_urls
[
arch
][
1
])
param
=
paddle
.
load
(
weight_path
)
model
.
set_dict
(
param
)
return
model
def
mobilenet_v3_small
(
pretrained
=
False
,
scale
=
1.0
,
**
kwargs
):
"""MobileNetV3 Small architecture model from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale (float, optional): Scale of channels in each layer. Default: 1.0.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v3_small
# build model
model = mobilenet_v3_small()
# build model and load imagenet pretrained weight
# model = mobilenet_v3_small(pretrained=True)
# build mobilenet v3 small model with scale=0.5
model = mobilenet_v3_small(scale=0.5)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
"""
model
=
_mobilenet_v3
(
"mobilenet_v3_small"
,
scale
=
scale
,
pretrained
=
pretrained
,
**
kwargs
)
return
model
def
mobilenet_v3_large
(
pretrained
=
False
,
scale
=
1.0
,
**
kwargs
):
"""MobileNetV3 Large architecture model from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
scale (float, optional): Scale of channels in each layer. Default: 1.0.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import mobilenet_v3_large
# build model
model = mobilenet_v3_large()
# build model and load imagenet pretrained weight
# model = mobilenet_v3_large(pretrained=True)
# build mobilenet v3 large model with scale=0.5
model = mobilenet_v3_large(scale=0.5)
x = paddle.rand([1, 3, 224, 224])
out = model(x)
print(out.shape)
"""
model
=
_mobilenet_v3
(
"mobilenet_v3_large"
,
scale
=
scale
,
pretrained
=
pretrained
,
**
kwargs
)
return
model
python/paddle/vision/models/utils.py
0 → 100644
浏览文件 @
68af310b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def
_make_divisible
(
v
,
divisor
=
8
,
min_value
=
None
):
"""
This function ensures that all layers have a channel number that is divisible by divisor
You can also see at https://github.com/keras-team/keras/blob/8ecef127f70db723c158dbe9ed3268b3d610ab55/keras/applications/mobilenet_v2.py#L505
Args:
divisor (int): The divisor for number of channels. Default: 8.
min_value (int, optional): The minimum value of number of channels, if it is None,
the default is divisor. Default: None.
"""
if
min_value
is
None
:
min_value
=
divisor
new_v
=
max
(
min_value
,
int
(
v
+
divisor
/
2
)
//
divisor
*
divisor
)
# Make sure that round down does not go down by more than 10%.
if
new_v
<
0.9
*
v
:
new_v
+=
divisor
return
new_v
python/paddle/vision/ops.py
浏览文件 @
68af310b
...
...
@@ -17,7 +17,7 @@ from ..fluid.layer_helper import LayerHelper
from
..fluid.data_feeder
import
check_variable_and_dtype
,
check_type
,
check_dtype
from
..fluid
import
core
,
layers
from
..fluid.layers
import
nn
,
utils
from
..nn
import
Layer
from
..nn
import
Layer
,
Conv2D
,
Sequential
,
ReLU
,
BatchNorm2D
from
..fluid.initializer
import
Normal
from
paddle.common_ops_import
import
*
...
...
@@ -1297,3 +1297,57 @@ class RoIAlign(Layer):
output_size
=
self
.
_output_size
,
spatial_scale
=
self
.
_spatial_scale
,
aligned
=
aligned
)
class
ConvNormActivation
(
Sequential
):
"""
Configurable block used for Convolution-Normalzation-Activation blocks.
This code is based on the torchvision code with modifications.
You can also see at https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py#L68
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block
kernel_size: (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None,
in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., paddle.nn.Layer], optional): Norm layer that will be stacked on top of the convolutiuon layer.
If ``None`` this layer wont be used. Default: ``paddle.nn.BatchNorm2d``
activation_layer (Callable[..., paddle.nn.Layer], optional): Activation function which will be stacked on top of the normalization
layer (if not ``None``), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``paddle.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
None
,
groups
=
1
,
norm_layer
=
BatchNorm2D
,
activation_layer
=
ReLU
,
dilation
=
1
,
bias
=
None
):
if
padding
is
None
:
padding
=
(
kernel_size
-
1
)
//
2
*
dilation
if
bias
is
None
:
bias
=
norm_layer
is
None
layers
=
[
Conv2D
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
=
dilation
,
groups
=
groups
,
bias_attr
=
bias
)
]
if
norm_layer
is
not
None
:
layers
.
append
(
norm_layer
(
out_channels
))
if
activation_layer
is
not
None
:
layers
.
append
(
activation_layer
())
super
().
__init__
(
*
layers
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录