提交 ea630846 编写于 作者: D dongshuilong

add_mixnet_rexnet

上级 42d2962d
......@@ -7,7 +7,8 @@
PaddleClas is a toolset for image classification tasks prepared for the industry and academia. It helps users train better computer vision models and apply them in real scenarios.
**Recent update**
- 2021.03.02 Add support for model quantization。
- 2021.04.15 Add `MixNet` and `ReXNet` pretrained models, `MixNet`'s Top-1 Acc on ImageNet-1k reaches 78.6% and `ReXNet` reaches 82.09%.
- 2021.03.02 Add support for model quantization.
- 2021.02.01 Add `RepVGG` pretrained models, whose Top-1 Acc on ImageNet-1k dataset reaches 79.65%.
- 2021.01.27 Add `ViT` and `DeiT` pretrained models, `ViT`'s Top-1 Acc on ImageNet-1k dataset reaches 85.13%, and `DeiT` reaches 85.1%.
- 2021.01.08 Add support for whl package and its usage, Model inference can be done by simply install paddleclas using pip.
......@@ -64,6 +65,8 @@ PaddleClas is a toolset for image classification tasks prepared for the industry
- [ResNeSt and RegNet series](#ResNeSt_and_RegNet_series)
- [Transformer series](#Transformer)
- [RepVGG series](#RepVGG)
- [MixNet series](#MixNet)
- [ReXNet series](#ReXNet)
- [Others](#Others)
- HS-ResNet: arxiv link: [https://arxiv.org/pdf/2010.07621.pdf](https://arxiv.org/pdf/2010.07621.pdf). Code and models are coming soon!
- Model training/evaluation
......@@ -401,6 +404,32 @@ Accuracy and inference time metrics of RepVGG series models are shown as follows
| RepVGG_B2g4 | 0.7881 | 0.9448 | | | | | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B2g4_pretrained.pdparams) |
| RepVGG_B3g4 | 0.7965 | 0.9485 | | | | | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B3g4_pretrained.pdparams) |
<a name="MixNet"></a>
### MixNet
Accuracy and inference time metrics of MixNet series models are shown as follows. More detailed information can be refered to [MixNet series tutorial](./docs/en/models/MixNet_en.md).
| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(M) | Params(M) | Download Address |
| -------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ |
| MixNet_S | 0.7628 | 0.9299 | | | 252.977 | 4.167 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MixNet_S_pretrained.pdparams) |
| MixNet_M | 0.7767 | 0.9364 | | | 357.119 | 5.065 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MixNet_M_pretrained.pdparams) |
| MixNet_L | 0.7860 | 0.9437 | | | 579.017 | 7.384 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MixNet_L_pretrained.pdparams) |
<a name="ReXNet"></a>
### ReXNet
Accuracy and inference time metrics of ReXNet series models are shown as follows. More detailed information can be refered to [ReXNet series tutorial](./docs/en/models/ReXNet_en.md).
| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | Download Address |
| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ |
| ReXNet_1_0 | 0.7746 | 0.9370 | | | 0.415 | 4.838 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_1_0_pretrained.pdparams) |
| ReXNet_1_3 | 0.7913 | 0.9464 | | | 0.683 | 7.611 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_1_3_pretrained.pdparams) |
| ReXNet_1_5 | 0.8006 | 0.9512 | | | 0.900 | 9.791 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_1_5_pretrained.pdparams) |
| ReXNet_2_0 | 0.8122 | 0.9536 | | | 1.561 | 16.449 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_2_0_pretrained.pdparams) |
| ReXNet_3_0 | 0.8209 | 0.9612 | | | 3.445 | 34.833 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_3_0_pretrained.pdparams) |
<a name="Others"></a>
### Others
......
......@@ -6,8 +6,9 @@
飞桨图像分类套件PaddleClas是飞桨为工业界和学术界所准备的一个图像分类任务的工具集,助力使用者训练出更好的视觉模型和应用落地。
**近期更新**
- 2021.04.15 添加`MixNet``ReXNet`系列模型,在ImageNet-1k上`MixNet` 模型Top1 Acc可达78.6%,`ReXNet`模型可达82.09%
- 2021.03.02 添加分类模型量化方法与使用教程。
- 2021.02.01 添加`RepVGG`系列模型,在ImageNet-1k上Top-1 Acc可达79.65%。
- 2021.01.27 添加`ViT``DeiT`模型,在ImageNet-1k上,`ViT`模型Top-1 Acc可达85.13%,`DeiT`模型可达85.1%。
......@@ -65,6 +66,8 @@
- [ResNeSt与RegNet系列](#ResNeSt与RegNet系列)
- [Transformer系列](#Transformer系列)
- [RepVGG系列](#RepVGG系列)
- [MixNet系列](#MixNet系列)
- [ReXNet系列](#ReXNet系列)
- [其他模型](#其他模型)
- HS-ResNet: arxiv文章链接: [https://arxiv.org/pdf/2010.07621.pdf](https://arxiv.org/pdf/2010.07621.pdf)。 代码和预训练模型即将开源,敬请期待。
- 模型训练/评估
......@@ -383,7 +386,6 @@ ViT(Vision Transformer)与DeiT(Data-efficient Image Transformers)系列
| DeiT_base_<br>distilled_patch16_384 | 0.851 | 0.973 | - | - | | 88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/DeiT_base_distilled_patch16_384_pretrained.pdparams) |
| | | | | | | | |
<a name="RepVGG系列"></a>
### RepVGG系列
......@@ -404,6 +406,32 @@ ViT(Vision Transformer)与DeiT(Data-efficient Image Transformers)系列
| RepVGG_B2g4 | 0.7881 | 0.9448 | | | | | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B2g4_pretrained.pdparams) |
| RepVGG_B3g4 | 0.7965 | 0.9485 | | | | | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/RepVGG_B3g4_pretrained.pdparams) |
<a name="MixNet系列"></a>
### MixNet系列
关于MixNet系列模型的精度、速度指标如下表所示,更多介绍可以参考:[MixNet系列模型文档](./docs/zh_CN/models/MixNet.md)
| 模型 | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(M) | Params(M) | 下载地址 |
| -------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ |
| MixNet_S | 0.7628 | 0.9299 | | | 252.977 | 4.167 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MixNet_S_pretrained.pdparams) |
| MixNet_M | 0.7767 | 0.9364 | | | 357.119 | 5.065 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MixNet_M_pretrained.pdparams) |
| MixNet_L | 0.7860 | 0.9437 | | | 579.017 | 7.384 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MixNet_L_pretrained.pdparams) |
<a name="ReXNet系列"></a>
### ReXNet系列
关于ReXNet系列模型的精度、速度指标如下表所示,更多介绍可以参考:[ReXNet系列模型文档](./docs/zh_CN/models/ReXNet.md)
| 模型 | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | Flops(G) | Params(M) | 下载地址 |
| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ |
| ReXNet_1_0 | 0.7746 | 0.9370 | | | 0.415 | 4.838 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_1_0_pretrained.pdparams) |
| ReXNet_1_3 | 0.7913 | 0.9464 | | | 0.683 | 7.611 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_1_3_pretrained.pdparams) |
| ReXNet_1_5 | 0.8006 | 0.9512 | | | 0.900 | 9.791 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_1_5_pretrained.pdparams) |
| ReXNet_2_0 | 0.8122 | 0.9536 | | | 1.561 | 16.449 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_2_0_pretrained.pdparams) |
| ReXNet_3_0 | 0.8209 | 0.9612 | | | 3.445 | 34.833 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ReXNet_3_0_pretrained.pdparams) |
<a name="其他模型"></a>
### 其他模型
......
# MixNet series
## Overview
MixNet is a lightweight network proposed by Google. The main idea of MixNet is to explore the combination of different size of kernels. The author found that the current network has the following two problems:
- Small convolution kernel has small receptive field and few parameters, but the accuracy is not high.
- The larger convolution kernel has larger receptive field and higher accuracy, but the parameters also increase a lot .
In order to solve the above two problems, MDConv(mixed depthwise convolution) is proposed. In this method, different size of kernels are mixed in a convolution operation block. And based on AutoML, a series of networks called MixNets are proposed, which have achieved good results on Imagenet. [paper](https://arxiv.org/pdf/1907.09595.pdf)
## Accuracy, FLOPS and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | FLOPS<br>(M) | Params<br/>(G |
| :------: | :---: | :---: | :---------------: | :----------: | ------------- |
| MixNet_S | 76.28 | 92.99 | 75.8 | 252.977 | 4.167 |
| MixNet_M | 77.67 | 93.64 | 77.0 | 357.119 | 5.065 |
| MixNet_L | 78.60 | 94.37 | 78.9 | 579.017 | 7.384 |
Inference speed and other information are coming soon.
# ReXNet series
## Overview
ReXNet is proposed by NAVER AI Lab, which is based on new network design principles. Aiming at the problem of representative bottleneck in the existing network, a set of design principles are proposed. The author believes that the conventional design produce representational bottlenecks, which would affect model performance. To investigate the representational bottleneck, the author study the matrix rank of the features generated by ten thousand random networks. Besides, entire layer’s channel configuration is also studied to design more accurate network architectures. In the end, the author proposes a set of simple and effective design principles to mitigate the representational bottleneck. [paper](https://arxiv.org/pdf/2007.00992.pdf)
## Accuracy, FLOPS and Parameters
| Models | Top1 | Top5 | Reference<br>top1 | FLOPS<br/>(G) | Params<br/>(M) |
| :--------: | :---: | :---: | :---------------: | :-----------: | -------------- |
| ReXNet_1_0 | 77.46 | 93.70 | 77.9 | 0.415 | 4.838 |
| ReXNet_1_3 | 79.13 | 94.64 | 79.5 | 0.683 | 7.611 |
| ReXNet_1_5 | 80.06 | 95.12 | 80.3 | 0.900 | 9.791 |
| ReXNet_2_0 | 81.22 | 95.36 | 81.6 | 1.561 | 16.449 |
| ReXNet_3_0 | 82.09 | 96.12 | 82.8 | 3.445 | 34.833 |
Inference speed and other information are coming soon.
......@@ -2,21 +2,18 @@
- 2021.04.15
- Add `MixNet` and `ReXNet` pretrained models, `MixNet`'s Top-1 Acc on ImageNet-1k reaches 78.6% and `ReXNet` reaches 82.09%.
- 2021.01.27
* Add ViT and DeiT pretrained models, ViT's Top-1 Acc on ImageNet reaches 81.05%, and DeiT reaches 85.5%.
- 2021.01.08
* Add support for whl package and its usage, Model inference can be done by simply install paddleclas using pip.
- 2020.12.16
* Add support for TensorRT when using cpp inference to obain more obvious acceleration.
- 2020.12.06
* Add `SE_HRNet_W64_C_ssld` pretrained model, whose Top-1 Acc on ImageNet1k dataset reaches 84.75%.
- 2020.11.23
* Add `GhostNet_x1_3_ssld` pretrained model, whose Top-1 Acc on ImageNet1k dataset reaches 79.38%.
- 2020.11.09
* Add `InceptionV3` architecture and pretrained model, whose Top-1 Acc on ImageNet1k dataset reaches 79.1%.
......
# MixNet系列
## 概述
MixNet是谷歌出的一篇关于轻量级网络的文章,主要工作就在于探索不同大小的卷积核的组合。作者发现目前网络有以下两个问题:
- 小的卷积核感受野小,参数少,但是准确率不高
- 大的卷积核感受野大,准确率相对略高,但是参数也相对增加了很多
为了解决上面两个问题,文中提出一种新的混合深度分离卷积(MDConv)(mixed depthwise convolution),将不同的核大小混合在一个卷积运算中,并且基于AutoML的搜索空间,提出了一系列的网络叫做MixNets,在ImageNet上取得了较好的效果。[论文地址](https://arxiv.org/pdf/1907.09595.pdf)
## 精度、FLOPS和参数量
| Models | Top1 | Top5 | Reference<br>top1| FLOPS<br>(M) | Params<br/>(M) |
|:--:|:--:|:--:|:--:|:--:|----|
| MixNet_S | 76.28 | 92.99 | 75.8 | 252.977 | 4.167 |
| MixNet_M | 77.67 | 93.64 | 77.0 | 357.119 | 5.065 |
| MixNet_L | 78.60 | 94.37 | 78.9 | 579.017 | 7.384 |
关于Inference speed等信息,敬请期待。
# ReXNet系列
## 概述
ReXNet是NAVER集团ClovaAI研发中心基于一种网络架构设计新范式而构建的网络。针对现有网络中存在的`Representational Bottleneck`问题,作者提出了一组新的设计原则。作者认为传统的网络架构设计范式会产生表达瓶颈,进而影响模型的性能。为研究此问题,作者研究了上万个随机网络生成特征的`matric rank`,同时进一步研究了网络层中通道配置方案。基于此,作者提出了一组简单而有效的设计原则,以消除表达瓶颈问题。[论文地址](https://arxiv.org/pdf/2007.00992.pdf)
## 精度、FLOPS和参数量
| Models | Top1 | Top5 | Reference<br>top1| FLOPS<br/>(G) | Params<br/>(M) |
|:--:|:--:|:--:|:--:|:--:|----|
| ReXNet_1_0 | 77.46 | 93.70 | 77.9 | 0.415 | 4.838 |
| ReXNet_1_3 | 79.13 | 94.64 | 79.5 | 0.683 | 7.611 |
| ReXNet_1_5 | 80.06 | 95.12 | 80.3 | 0.900 | 9.791 |
| ReXNet_2_0 | 81.22 | 95.36 | 81.6 | 1.561 | 16.449 |
| ReXNet_3_0 | 82.09 | 96.12 | 82.8 | 3.445 | 34.833 |
关于Inference speed等信息,敬请期待。
......@@ -2,34 +2,27 @@
- 2021.04.15
- 添加`MixNet``ReXNet`系列模型,在ImageNet-1k上`MixNet` 模型Top1 Acc可达78.6%,`ReXNet`模型可达82.09%
- 2021.01.27
* 添加ViT与DeiT模型,在ImageNet上,ViT模型Top-1 Acc可达81.05%,DeiT模型可达85.5%。
- 2021.01.08
* 添加whl包及其使用说明,直接安装paddleclas whl包,即可快速完成模型预测。
- 2020.12.16
* 添加对cpp预测的tensorRT支持,预测加速更明显。
- 2020.12.06
* 添加SE_HRNet_W64_C_ssld模型,在ImageNet上Top-1 Acc可达0.8475。
- 2020.11.23
* 添加GhostNet_x1_3_ssld模型,在ImageNet上Top-1 Acc可达0.7938。
- 2020.11.09
* 添加InceptionV3结构和模型,在ImageNet上Top-1 Acc可达0.791。
- 2020.10.20
* 添加Res2Net50_vd_26w_4s_ssld模型,在ImageNet上Top-1 Acc可达0.831;添加Res2Net101_vd_26w_4s_ssld模型,在ImageNet上Top-1 Acc可达0.839。
- 2020.10.12
* 添加Paddle-Lite demo。
- 2020.10.10
* 添加cpp inference demo。
* 添加FAQ30问。
- 2020.09.17
* 添加HRNet_W48_C_ssld模型,在ImageNet上Top-1 Acc可达0.836;添加ResNet34_vd_ssld模型,在ImageNet上Top-1 Acc可达0.797。
......
......@@ -47,4 +47,5 @@ from .vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT
from .distilled_vision_transformer import DeiT_tiny_patch16_224, DeiT_small_patch16_224, DeiT_base_patch16_224, DeiT_tiny_distilled_patch16_224, DeiT_small_distilled_patch16_224, DeiT_base_distilled_patch16_224, DeiT_base_patch16_384, DeiT_base_distilled_patch16_384
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
from .repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG_B1, RepVGG_B2, RepVGG_B3, RepVGG_B1g2, RepVGG_B1g4, RepVGG_B2g2, RepVGG_B2g4, RepVGG_B3g2, RepVGG_B3g4
from .mixnet import MixNet_S, MixNet_M, MixNet_L
from .rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MixNet for ImageNet-1K, implemented in Paddle.
Original paper: 'MixConv: Mixed Depthwise Convolutional Kernels,'
https://arxiv.org/abs/1907.09595.
"""
__all__ = ['MixNet_S', 'MixNet_M', 'MixNet_L']
import os
from inspect import isfunction
from functools import reduce
import paddle
import paddle.nn as nn
class Identity(nn.Layer):
"""
Identity block.
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
def round_channels(channels, divisor=8):
"""
Round weighted channel number (make divisible operation).
Parameters:
----------
channels : int or float
Original number of channels.
divisor : int, default 8
Alignment value.
Returns:
-------
int
Weighted number of channels.
"""
rounded_channels = max(
int(channels + divisor / 2.0) // divisor * divisor, divisor)
if float(rounded_channels) < 0.9 * channels:
rounded_channels += divisor
return rounded_channels
def get_activation_layer(activation):
"""
Create activation layer from string/function.
Parameters:
----------
activation : function, or str, or nn.Module
Activation function or name of activation function.
Returns:
-------
nn.Module
Activation layer.
"""
assert activation is not None
if isfunction(activation):
return activation()
elif isinstance(activation, str):
if activation == "relu":
return nn.ReLU()
elif activation == "relu6":
return nn.ReLU6()
elif activation == "swish":
return nn.Swish()
elif activation == "hswish":
return nn.Hardswish()
elif activation == "sigmoid":
return nn.Sigmoid()
elif activation == "hsigmoid":
return nn.Hardsigmoid()
elif activation == "identity":
return Identity()
else:
raise NotImplementedError()
else:
assert isinstance(activation, nn.Layer)
return activation
class ConvBlock(nn.Layer):
"""
Standard convolution block with Batch normalization and activation.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel_size : int or tuple/list of 2 int
Convolution window size.
stride : int or tuple/list of 2 int
Strides of the convolution.
padding : int, or tuple/list of 2 int, or tuple/list of 4 int
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
groups : int, default 1
Number of groups.
bias : bool, default False
Whether the layer uses a bias vector.
use_bn : bool, default True
Whether to use BatchNorm layer.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU()
Activation function or name of activation function.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=1,
groups=1,
bias=False,
use_bn=True,
bn_eps=1e-5,
activation=nn.ReLU()):
super(ConvBlock, self).__init__()
self.activate = (activation is not None)
self.use_bn = use_bn
self.use_pad = (isinstance(padding, (list, tuple)) and
(len(padding) == 4))
if self.use_pad:
self.pad = padding
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias_attr=bias,
weight_attr=None)
if self.use_bn:
self.bn = nn.BatchNorm2D(num_features=out_channels, epsilon=bn_eps)
if self.activate:
self.activ = get_activation_layer(activation)
def forward(self, x):
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
if self.activate:
x = self.activ(x)
return x
class SEBlock(nn.Layer):
def __init__(self,
channels,
reduction=16,
mid_channels=None,
round_mid=False,
use_conv=True,
mid_activation=nn.ReLU(),
out_activation=nn.Sigmoid()):
super(SEBlock, self).__init__()
self.use_conv = use_conv
if mid_channels is None:
mid_channels = channels // reduction if not round_mid else round_channels(
float(channels) / reduction)
self.pool = nn.AdaptiveAvgPool2D(output_size=1)
if use_conv:
self.conv1 = nn.Conv2D(
in_channels=channels,
out_channels=mid_channels,
kernel_size=1,
stride=1,
groups=1,
bias_attr=True,
weight_attr=None)
else:
self.fc1 = nn.Linear(
in_features=channels, out_features=mid_channels)
self.activ = get_activation_layer(mid_activation)
if use_conv:
self.conv2 = nn.Conv2D(
in_channels=mid_channels,
out_channels=channels,
kernel_size=1,
stride=1,
groups=1,
bias_attr=True,
weight_attr=None)
else:
self.fc2 = nn.Linear(
in_features=mid_channels, out_features=channels)
self.sigmoid = get_activation_layer(out_activation)
def forward(self, x):
w = self.pool(x)
if not self.use_conv:
w = w.reshape(shape=[w.shape[0], -1])
w = self.conv1(w) if self.use_conv else self.fc1(w)
w = self.activ(w)
w = self.conv2(w) if self.use_conv else self.fc2(w)
w = self.sigmoid(w)
if not self.use_conv:
w = w.unsqueeze(2).unsqueeze(3)
x = x * w
return x
class MixConv(nn.Layer):
"""
Mixed convolution layer from 'MixConv: Mixed Depthwise Convolutional Kernels,'
https://arxiv.org/abs/1907.09595.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel_size : int or tuple/list of int, or tuple/list of tuple/list of 2 int
Convolution window size.
stride : int or tuple/list of 2 int
Strides of the convolution.
padding : int or tuple/list of int, or tuple/list of tuple/list of 2 int
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
groups : int, default 1
Number of groups.
bias : bool, default False
Whether the layer uses a bias vector.
axis : int, default 1
The axis on which to concatenate the outputs.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=1,
groups=1,
bias=False,
axis=1):
super(MixConv, self).__init__()
kernel_size = kernel_size if isinstance(kernel_size,
list) else [kernel_size]
padding = padding if isinstance(padding, list) else [padding]
kernel_count = len(kernel_size)
self.splitted_in_channels = self.split_channels(in_channels,
kernel_count)
splitted_out_channels = self.split_channels(out_channels, kernel_count)
for i, kernel_size_i in enumerate(kernel_size):
in_channels_i = self.splitted_in_channels[i]
out_channels_i = splitted_out_channels[i]
padding_i = padding[i]
_ = self.add_sublayer(
name=str(i),
sublayer=nn.Conv2D(
in_channels=in_channels_i,
out_channels=out_channels_i,
kernel_size=kernel_size_i,
stride=stride,
padding=padding_i,
dilation=dilation,
groups=(out_channels_i
if out_channels == groups else groups),
bias_attr=bias,
weight_attr=None))
self.axis = axis
def forward(self, x):
xx = paddle.split(x, self.splitted_in_channels, axis=self.axis)
xx = paddle.split(x, self.splitted_in_channels, axis=self.axis)
out = [
conv_i(x_i) for x_i, conv_i in zip(xx, self._sub_layers.values())
]
x = paddle.concat(tuple(out), axis=self.axis)
return x
@staticmethod
def split_channels(channels, kernel_count):
splitted_channels = [channels // kernel_count] * kernel_count
splitted_channels[0] += channels - sum(splitted_channels)
return splitted_channels
class MixConvBlock(nn.Layer):
"""
Mixed convolution block with Batch normalization and activation.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel_size : int or tuple/list of int, or tuple/list of tuple/list of 2 int
Convolution window size.
stride : int or tuple/list of 2 int
Strides of the convolution.
padding : int or tuple/list of int, or tuple/list of tuple/list of 2 int
Padding value for convolution layer.
dilation : int or tuple/list of 2 int, default 1
Dilation value for convolution layer.
groups : int, default 1
Number of groups.
bias : bool, default False
Whether the layer uses a bias vector.
use_bn : bool, default True
Whether to use BatchNorm layer.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str or None, default nn.ReLU()
Activation function or name of activation function.
activate : bool, default True
Whether activate the convolution block.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=1,
groups=1,
bias=False,
use_bn=True,
bn_eps=1e-5,
activation=nn.ReLU()):
super(MixConvBlock, self).__init__()
self.activate = (activation is not None)
self.use_bn = use_bn
self.conv = MixConv(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
if self.use_bn:
self.bn = nn.BatchNorm2D(num_features=out_channels, epsilon=bn_eps)
if self.activate:
self.activ = get_activation_layer(activation)
def forward(self, x):
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
if self.activate:
x = self.activ(x)
return x
def mixconv1x1_block(in_channels,
out_channels,
kernel_count,
stride=1,
groups=1,
bias=False,
use_bn=True,
bn_eps=1e-5,
activation=nn.ReLU()):
"""
1x1 version of the mixed convolution block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel_count : int
Kernel count.
stride : int or tuple/list of 2 int, default 1
Strides of the convolution.
groups : int, default 1
Number of groups.
bias : bool, default False
Whether the layer uses a bias vector.
use_bn : bool, default True
Whether to use BatchNorm layer.
bn_eps : float, default 1e-5
Small float added to variance in Batch norm.
activation : function or str, or None, default nn.ReLU()
Activation function or name of activation function.
"""
return MixConvBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=([1] * kernel_count),
stride=stride,
padding=([0] * kernel_count),
groups=groups,
bias=bias,
use_bn=use_bn,
bn_eps=bn_eps,
activation=activation)
class MixUnit(nn.Layer):
"""
MixNet unit.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
exp_channels : int
Number of middle (expanded) channels.
stride : int or tuple/list of 2 int
Strides of the second convolution layer.
exp_kernel_count : int
Expansion convolution kernel count for each unit.
conv1_kernel_count : int
Conv1 kernel count for each unit.
conv2_kernel_count : int
Conv2 kernel count for each unit.
exp_factor : int
Expansion factor for each unit.
se_factor : int
SE reduction factor for each unit.
activation : str
Activation function or name of activation function.
"""
def __init__(self, in_channels, out_channels, stride, exp_kernel_count,
conv1_kernel_count, conv2_kernel_count, exp_factor, se_factor,
activation):
super(MixUnit, self).__init__()
assert exp_factor >= 1
assert se_factor >= 0
self.residual = (in_channels == out_channels) and (stride == 1)
self.use_se = se_factor > 0
mid_channels = exp_factor * in_channels
self.use_exp_conv = exp_factor > 1
if self.use_exp_conv:
if exp_kernel_count == 1:
self.exp_conv = ConvBlock(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=1,
stride=1,
padding=0,
groups=1,
bias=False,
use_bn=True,
bn_eps=1e-5,
activation=activation)
else:
self.exp_conv = mixconv1x1_block(
in_channels=in_channels,
out_channels=mid_channels,
kernel_count=exp_kernel_count,
activation=activation)
if conv1_kernel_count == 1:
self.conv1 = ConvBlock(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
stride=stride,
padding=1,
dilation=1,
groups=mid_channels,
bias=False,
use_bn=True,
bn_eps=1e-5,
activation=activation)
else:
self.conv1 = MixConvBlock(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=[3 + 2 * i for i in range(conv1_kernel_count)],
stride=stride,
padding=[1 + i for i in range(conv1_kernel_count)],
groups=mid_channels,
activation=activation)
if self.use_se:
self.se = SEBlock(
channels=mid_channels,
reduction=(exp_factor * se_factor),
round_mid=False,
mid_activation=activation)
if conv2_kernel_count == 1:
self.conv2 = ConvBlock(
in_channels=mid_channels,
out_channels=out_channels,
activation=None,
kernel_size=1,
stride=1,
padding=0,
groups=1,
bias=False,
use_bn=True,
bn_eps=1e-5)
else:
self.conv2 = mixconv1x1_block(
in_channels=mid_channels,
out_channels=out_channels,
kernel_count=conv2_kernel_count,
activation=None)
def forward(self, x):
if self.residual:
identity = x
if self.use_exp_conv:
x = self.exp_conv(x)
x = self.conv1(x)
if self.use_se:
x = self.se(x)
x = self.conv2(x)
if self.residual:
x = x + identity
return x
class MixInitBlock(nn.Layer):
"""
MixNet specific initial block.
Parameters:
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
"""
def __init__(self, in_channels, out_channels):
super(MixInitBlock, self).__init__()
self.conv1 = ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
stride=2,
kernel_size=3,
padding=1)
self.conv2 = MixUnit(
in_channels=out_channels,
out_channels=out_channels,
stride=1,
exp_kernel_count=1,
conv1_kernel_count=1,
conv2_kernel_count=1,
exp_factor=1,
se_factor=0,
activation="relu")
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class MixNet(nn.Layer):
"""
MixNet model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
https://arxiv.org/abs/1907.09595.
Parameters:
----------
channels : list of list of int
Number of output channels for each unit.
init_block_channels : int
Number of output channels for the initial unit.
final_block_channels : int
Number of output channels for the final block of the feature extractor.
exp_kernel_counts : list of list of int
Expansion convolution kernel count for each unit.
conv1_kernel_counts : list of list of int
Conv1 kernel count for each unit.
conv2_kernel_counts : list of list of int
Conv2 kernel count for each unit.
exp_factors : list of list of int
Expansion factor for each unit.
se_factors : list of list of int
SE reduction factor for each unit.
in_channels : int, default 3
Number of input channels.
in_size : tuple of two ints, default (224, 224)
Spatial size of the expected input image.
class_dim : int, default 1000
Number of classification classes.
"""
def __init__(self,
channels,
init_block_channels,
final_block_channels,
exp_kernel_counts,
conv1_kernel_counts,
conv2_kernel_counts,
exp_factors,
se_factors,
in_channels=3,
in_size=(224, 224),
class_dim=1000):
super(MixNet, self).__init__()
self.in_size = in_size
self.class_dim = class_dim
self.features = nn.Sequential()
self.features.add_sublayer(
"init_block",
MixInitBlock(
in_channels=in_channels, out_channels=init_block_channels))
in_channels = init_block_channels
for i, channels_per_stage in enumerate(channels):
stage = nn.Sequential()
for j, out_channels in enumerate(channels_per_stage):
stride = 2 if ((j == 0) and (i != 3)) or (
(j == len(channels_per_stage) // 2) and (i == 3)) else 1
exp_kernel_count = exp_kernel_counts[i][j]
conv1_kernel_count = conv1_kernel_counts[i][j]
conv2_kernel_count = conv2_kernel_counts[i][j]
exp_factor = exp_factors[i][j]
se_factor = se_factors[i][j]
activation = "relu" if i == 0 else "swish"
stage.add_sublayer(
"unit{}".format(j + 1),
MixUnit(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
exp_kernel_count=exp_kernel_count,
conv1_kernel_count=conv1_kernel_count,
conv2_kernel_count=conv2_kernel_count,
exp_factor=exp_factor,
se_factor=se_factor,
activation=activation))
in_channels = out_channels
self.features.add_sublayer("stage{}".format(i + 1), stage)
self.features.add_sublayer(
"final_block",
ConvBlock(
in_channels=in_channels,
out_channels=final_block_channels,
kernel_size=1,
stride=1,
padding=0,
groups=1,
bias=False,
use_bn=True,
bn_eps=1e-5,
activation=nn.ReLU()))
in_channels = final_block_channels
self.features.add_sublayer(
"final_pool", nn.AvgPool2D(
kernel_size=7, stride=1))
self.output = nn.Linear(
in_features=in_channels, out_features=class_dim)
def forward(self, x):
x = self.features(x)
reshape_dim = reduce(lambda x, y: x * y, x.shape[1:])
x = x.reshape(shape=[x.shape[0], reshape_dim])
x = self.output(x)
return x
def get_mixnet(version,
width_scale,
model_name=None,
pretrained=False,
root=os.path.join("~", ".paddle", "models"),
**kwargs):
"""
Create MixNet model with specific parameters.
Parameters:
----------
version : str
Version of MobileNetV3 ('s' or 'm').
width_scale : float
Scale factor for width of layers.
model_name : str or None, default None
Model name for loading pretrained model.
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.torch/models'
Location for keeping the model parameters.
"""
if version == "s":
init_block_channels = 16
channels = [[24, 24], [40, 40, 40, 40], [80, 80, 80],
[120, 120, 120, 200, 200, 200]]
exp_kernel_counts = [[2, 2], [1, 2, 2, 2], [1, 1, 1],
[2, 2, 2, 1, 1, 1]]
conv1_kernel_counts = [[1, 1], [3, 2, 2, 2], [3, 2, 2],
[3, 4, 4, 5, 4, 4]]
conv2_kernel_counts = [[2, 2], [1, 2, 2, 2], [2, 2, 2],
[2, 2, 2, 1, 2, 2]]
exp_factors = [[6, 3], [6, 6, 6, 6], [6, 6, 6], [6, 3, 3, 6, 6, 6]]
se_factors = [[0, 0], [2, 2, 2, 2], [4, 4, 4], [2, 2, 2, 2, 2, 2]]
elif version == "m":
init_block_channels = 24
channels = [[32, 32], [40, 40, 40, 40], [80, 80, 80, 80],
[120, 120, 120, 120, 200, 200, 200, 200]]
exp_kernel_counts = [[2, 2], [1, 2, 2, 2], [1, 2, 2, 2],
[1, 2, 2, 2, 1, 1, 1, 1]]
conv1_kernel_counts = [[3, 1], [4, 2, 2, 2], [3, 4, 4, 4],
[1, 4, 4, 4, 4, 4, 4, 4]]
conv2_kernel_counts = [[2, 2], [1, 2, 2, 2], [1, 2, 2, 2],
[1, 2, 2, 2, 1, 2, 2, 2]]
exp_factors = [[6, 3], [6, 6, 6, 6], [6, 6, 6, 6],
[6, 3, 3, 3, 6, 6, 6, 6]]
se_factors = [[0, 0], [2, 2, 2, 2], [4, 4, 4, 4],
[2, 2, 2, 2, 2, 2, 2, 2]]
else:
raise ValueError("Unsupported MixNet version {}".format(version))
final_block_channels = 1536
if width_scale != 1.0:
channels = [[round_channels(cij * width_scale) for cij in ci]
for ci in channels]
init_block_channels = round_channels(init_block_channels * width_scale)
net = MixNet(
channels=channels,
init_block_channels=init_block_channels,
final_block_channels=final_block_channels,
exp_kernel_counts=exp_kernel_counts,
conv1_kernel_counts=conv1_kernel_counts,
conv2_kernel_counts=conv2_kernel_counts,
exp_factors=exp_factors,
se_factors=se_factors,
**kwargs)
return net
def MixNet_S(**kwargs):
"""
MixNet-S model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
https://arxiv.org/abs/1907.09595.
Parameters:
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.torch/models'
Location for keeping the model parameters.
"""
return get_mixnet(
version="s", width_scale=1.0, model_name="MixNet_S", **kwargs)
def MixNet_M(**kwargs):
"""
MixNet-M model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
https://arxiv.org/abs/1907.09595.
Parameters:
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.torch/models'
Location for keeping the model parameters.
"""
return get_mixnet(
version="m", width_scale=1.0, model_name="MixNet_M", **kwargs)
def MixNet_L(**kwargs):
"""
MixNet-L model from 'MixConv: Mixed Depthwise Convolutional Kernels,'
https://arxiv.org/abs/1907.09595.
Parameters:
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.torch/models'
Location for keeping the model parameters.
"""
return get_mixnet(
version="m", width_scale=1.3, model_name="MixNet_L", **kwargs)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
from math import ceil
__all__ = [
"ReXNet_1_0", "ReXNet_1_3", "ReXNet_1_5", "ReXNet_2_0", "ReXNet_3_0"
]
def conv_bn_act(out,
in_channels,
channels,
kernel=1,
stride=1,
pad=0,
num_group=1,
active=True,
relu6=False):
out.append(
nn.Conv2D(
in_channels,
channels,
kernel,
stride,
pad,
groups=num_group,
bias_attr=False))
out.append(nn.BatchNorm2D(channels))
if active:
out.append(nn.ReLU6() if relu6 else nn.ReLU())
def conv_bn_swish(out,
in_channels,
channels,
kernel=1,
stride=1,
pad=0,
num_group=1):
out.append(
nn.Conv2D(
in_channels,
channels,
kernel,
stride,
pad,
groups=num_group,
bias_attr=False))
out.append(nn.BatchNorm2D(channels))
out.append(nn.Swish())
class SE(nn.Layer):
def __init__(self, in_channels, channels, se_ratio=12):
super(SE, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.fc = nn.Sequential(
nn.Conv2D(
in_channels, channels // se_ratio, kernel_size=1, padding=0),
nn.BatchNorm2D(channels // se_ratio),
nn.ReLU(),
nn.Conv2D(
channels // se_ratio, channels, kernel_size=1, padding=0),
nn.Sigmoid())
def forward(self, x):
y = self.avg_pool(x)
y = self.fc(y)
return x * y
class LinearBottleneck(nn.Layer):
def __init__(self,
in_channels,
channels,
t,
stride,
use_se=True,
se_ratio=12,
**kwargs):
super(LinearBottleneck, self).__init__(**kwargs)
self.use_shortcut = stride == 1 and in_channels <= channels
self.in_channels = in_channels
self.out_channels = channels
out = []
if t != 1:
dw_channels = in_channels * t
conv_bn_swish(out, in_channels=in_channels, channels=dw_channels)
else:
dw_channels = in_channels
conv_bn_act(
out,
in_channels=dw_channels,
channels=dw_channels,
kernel=3,
stride=stride,
pad=1,
num_group=dw_channels,
active=False)
if use_se:
out.append(SE(dw_channels, dw_channels, se_ratio))
out.append(nn.ReLU6())
conv_bn_act(
out,
in_channels=dw_channels,
channels=channels,
active=False,
relu6=True)
self.out = nn.Sequential(*out)
def forward(self, x):
out = self.out(x)
if self.use_shortcut:
out[:, 0:self.in_channels] += x
return out
class ReXNetV1(nn.Layer):
def __init__(self,
input_ch=16,
final_ch=180,
width_mult=1.0,
depth_mult=1.0,
class_dim=1000,
use_se=True,
se_ratio=12,
dropout_ratio=0.2,
bn_momentum=0.9):
super(ReXNetV1, self).__init__()
layers = [1, 2, 2, 3, 3, 5]
strides = [1, 2, 2, 2, 1, 2]
use_ses = [False, False, True, True, True, True]
layers = [ceil(element * depth_mult) for element in layers]
strides = sum([[element] + [1] * (layers[idx] - 1)
for idx, element in enumerate(strides)], [])
if use_se:
use_ses = sum([[element] * layers[idx]
for idx, element in enumerate(use_ses)], [])
else:
use_ses = [False] * sum(layers[:])
ts = [1] * layers[0] + [6] * sum(layers[1:])
self.depth = sum(layers[:]) * 3
stem_channel = 32 / width_mult if width_mult < 1.0 else 32
inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch
features = []
in_channels_group = []
channels_group = []
# The following channel configuration is a simple instance to make each layer become an expand layer.
for i in range(self.depth // 3):
if i == 0:
in_channels_group.append(int(round(stem_channel * width_mult)))
channels_group.append(int(round(inplanes * width_mult)))
else:
in_channels_group.append(int(round(inplanes * width_mult)))
inplanes += final_ch / (self.depth // 3 * 1.0)
channels_group.append(int(round(inplanes * width_mult)))
conv_bn_swish(
features,
3,
int(round(stem_channel * width_mult)),
kernel=3,
stride=2,
pad=1)
for block_idx, (in_c, c, t, s, se) in enumerate(
zip(in_channels_group, channels_group, ts, strides, use_ses)):
features.append(
LinearBottleneck(
in_channels=in_c,
channels=c,
t=t,
stride=s,
use_se=se,
se_ratio=se_ratio))
pen_channels = int(1280 * width_mult)
conv_bn_swish(features, c, pen_channels)
features.append(nn.AdaptiveAvgPool2D(1))
self.features = nn.Sequential(*features)
self.output = nn.Sequential(
nn.Dropout(dropout_ratio),
nn.Conv2D(
pen_channels, class_dim, 1, bias_attr=True))
def forward(self, x):
x = self.features(x)
x = self.output(x).squeeze(axis=-1).squeeze(axis=-1)
return x
def ReXNet_1_0(**args):
return ReXNetV1(width_mult=1.0, **args)
def ReXNet_1_3(**args):
return ReXNetV1(width_mult=1.3, **args)
def ReXNet_1_5(**args):
return ReXNetV1(width_mult=1.5, **args)
def ReXNet_2_0(**args):
return ReXNetV1(width_mult=2.0, **args)
def ReXNet_3_0(**args):
return ReXNetV1(width_mult=3.0, **args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册