diff --git a/README.md b/README.md
index 965d1b3dc5c3ad68875ddf77a0d8e73419b05e71..d69646e9ce553f2c0a178c639e149bd82b9af789 100644
--- a/README.md
+++ b/README.md
@@ -100,10 +100,6 @@ PaddleClas的安装说明、模型训练、预测、评估以及模型微调(f
近年来,学术界和工业界广泛关注图像中目标检测任务,而图像分类的网络结构以及预训练模型效果直接影响目标检测的效果。PaddleDetection使用PaddleClas的82.39%的ResNet50_vd的预训练模型,结合自身丰富的检测算子,提供了一种面向服务器端应用的目标检测方案,PSS-DET (Practical Server Side Detection)。该方案融合了多种只增加少许计算量,但是可以有效提升两阶段Faster RCNN目标检测效果的策略,包括检测模型剪裁、使用分类效果更优的预训练模型、DCNv2、Cascade RCNN、AutoAugment、Libra sampling以及多尺度训练。其中基于82.39%的R50_vd_ssld预训练模型,与79.12%的R50_vd的预训练模型相比,检测效果可以提升1.5%。在COCO目标检测数据集上测试PSS-DET,当V100单卡预测速度为61FPS时,mAP是41.6%,预测速度为20FPS时,mAP是47.8%。详情请参考[**通用目标检测章节**](https://paddleclas.readthedocs.io/zh_CN/latest/application/object_detection.html)。
-- TODO
-- [ ] PaddleClas在OCR任务中的应用
-- [ ] PaddleClas在人脸检测和识别中的应用
-
## 工业级应用部署工具
PaddlePaddle提供了一系列实用工具,便于工业应用部署PaddleClas,具体请参考文档教程中的[**实用工具章节**](https://paddleclas.readthedocs.io/zh_CN/latest/extension/index.html)。
diff --git a/configs/EfficientNet/EfficientLite0.yaml b/configs/EfficientNet/EfficientLite0.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1f3e739a08ee8c29927b18956dfc7ad31298b184
--- /dev/null
+++ b/configs/EfficientNet/EfficientLite0.yaml
@@ -0,0 +1,91 @@
+mode: 'train'
+ARCHITECTURE:
+ name: "EfficientNetLite0"
+ params:
+ is_test: False
+ padding_type : "SAME"
+ override_params:
+ drop_connect_rate: 0.1
+ fix_head_stem: True
+ relu_fn: True
+
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1281167
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 360
+topk: 5
+image_shape: [3, 224, 224]
+use_ema: True
+ema_decay: 0.9999
+use_aa: True
+ls_epsilon: 0.1
+
+LEARNING_RATE:
+ function: 'ExponentialWarmup'
+ params:
+ lr: 0.032
+
+OPTIMIZER:
+ function: 'RMSProp'
+ params:
+ momentum: 0.9
+ rho: 0.9
+ epsilon: 0.001
+ regularizer:
+ function: 'L2'
+ factor: 0.00001
+
+TRAIN:
+ batch_size: 512
+ num_workers: 4
+ file_list: "./dataset/ILSVRC2012/train_list.txt"
+ data_dir: "./dataset/ILSVRC2012/"
+ shuffle_seed: 0
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ to_np: False
+ channel_first: False
+ - RandCropImage:
+ size: 224
+ interpolation: 1
+ - RandFlipImage:
+ flip_code: 1
+ - AutoAugment:
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+
+
+
+VALID:
+ batch_size: 128
+ num_workers: 4
+ file_list: "./dataset/ILSVRC2012/val_list.txt"
+ data_dir: "./dataset/ILSVRC2012/"
+ shuffle_seed: 0
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ to_np: False
+ channel_first: False
+ - ResizeImage:
+ interpolation: 1
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+
+
diff --git a/configs/EfficientNet/EfficientNetB0.yaml b/configs/EfficientNet/EfficientNetB0.yaml
index 01932d43a591f300d0aa7e21784c41ef94479572..0bf2b13cdda859d1a88dfe140d7d89d79b4dbd2b 100644
--- a/configs/EfficientNet/EfficientNetB0.yaml
+++ b/configs/EfficientNet/EfficientNetB0.yaml
@@ -19,7 +19,6 @@ topk: 5
image_shape: [3, 224, 224]
use_ema: True
ema_decay: 0.9999
-use_aa: True
ls_epsilon: 0.1
LEARNING_RATE:
@@ -46,7 +45,7 @@ TRAIN:
transforms:
- DecodeImage:
to_rgb: True
- to_np: Fals
+ to_np: False
channel_first: False
- RandCropImage:
size: 224
@@ -85,5 +84,3 @@ VALID:
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
-
-
diff --git a/configs/RegNet/RegNetX_4GF.yaml b/configs/RegNet/RegNetX_4GF.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..21001b534c8eaac86eb39eaaafea1ae931201c58
--- /dev/null
+++ b/configs/RegNet/RegNetX_4GF.yaml
@@ -0,0 +1,75 @@
+mode: 'train'
+ARCHITECTURE:
+ name: 'RegNetX_4GF'
+
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1281167
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 100
+topk: 5
+image_shape: [3, 224, 224]
+
+use_mix: False
+ls_epsilon: -1
+
+LEARNING_RATE:
+ function: 'CosineWarmup'
+ params:
+ lr: 0.4
+ warmup_epoch: 5
+
+OPTIMIZER:
+ function: 'Momentum'
+ params:
+ momentum: 0.9
+ regularizer:
+ function: 'L2'
+ factor: 0.000050
+
+TRAIN:
+ batch_size: 512
+ num_workers: 4
+ file_list: "./dataset/ILSVRC2012/train_list.txt"
+ data_dir: "./dataset/ILSVRC2012/"
+ shuffle_seed: 0
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ to_np: False
+ channel_first: False
+ - RandCropImage:
+ size: 224
+ - RandFlipImage:
+ flip_code: 1
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+
+VALID:
+ batch_size: 256
+ num_workers: 4
+ file_list: "./dataset/ILSVRC2012/val_list.txt"
+ data_dir: "./dataset/ILSVRC2012/"
+ shuffle_seed: 0
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ to_np: False
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
diff --git a/docs/en/extension/paddle_inference_en.md b/docs/en/extension/paddle_inference_en.md
index 3f0f29830aa0e612c122d831f3c12cb015874f23..14385b62d91e3677af52887ae48e14b4a2608f89 100644
--- a/docs/en/extension/paddle_inference_en.md
+++ b/docs/en/extension/paddle_inference_en.md
@@ -100,7 +100,7 @@ python tools/export_model.py \
The complete example is provided in the `tools/infer/predict.py`,just execute the following command to complete the prediction:
```
-python ./predict.py \
+python ./tools/infer/predict.py \
-i=./test.jpeg \
-m=./resnet50-vd/model \
-p=./resnet50-vd/params \
diff --git a/docs/en/models/Mobile_en.md b/docs/en/models/Mobile_en.md
index 469a1643a853733be3c9536cd049d85e265ca710..e6f7c629577997eb93db80aeb673bcba6fb848b1 100644
--- a/docs/en/models/Mobile_en.md
+++ b/docs/en/models/Mobile_en.md
@@ -10,6 +10,7 @@ The ShuffleNet series network is the lightweight network structure proposed by M
MobileNetV3 is a new and lightweight network based on NAS proposed by Google in 2019. In order to further improve the effect, the activation functions of relu and sigmoid were replaced with hard_swish and hard_sigmoid activation functions, and some improved strategies were introduced to reduce the amount of network computing.
+GhosttNet is a brand-new lightweight network structure proposed by Huawei in 2020. By introducing the ghost module, the problem of redundant calculation of features in traditional deep networks is greatly alleviated, which greatly reduces the amount of network parameters and calculations.
![](../../images/models/mobile_arm_top1.png)
@@ -57,6 +58,9 @@ Currently there are 32 pretrained models of the mobile series open source by Pad
| ShuffleNetV2_x1_5 | 0.716 | 0.902 | 0.726 | | 0.580 | 3.470 |
| ShuffleNetV2_x2_0 | 0.732 | 0.912 | 0.749 | | 1.120 | 7.320 |
| ShuffleNetV2_swish | 0.700 | 0.892 | | | 0.290 | 2.260 |
+| GhostNet_x0_5 | 0.668 | 0.869 | 0.662 | 0.866 | 0.041 | 2.600 |
+| GhostNet_x1_0 | 0.740 | 0.916 | 0.739 | 0.914 | 0.147 | 5.200 |
+| GhostNet_x1_3 | 0.757 | 0.925 | 0.757 | 0.927 | 0.220 | 7.300 |
## Inference speed and storage size based on SD855
diff --git a/docs/en/models/models_intro_en.md b/docs/en/models/models_intro_en.md
index a6d67186c0cc4e8063572fb70e838bbbf6166fe8..958b6f06233e4aeb9956bd5c2935a3f4c02f4c46 100644
--- a/docs/en/models/models_intro_en.md
+++ b/docs/en/models/models_intro_en.md
@@ -93,6 +93,10 @@ python tools/infer/predict.py \
- [ShuffleNetV2_x1_5](https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_x1_5_pretrained.tar)
- [ShuffleNetV2_x2_0](https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_x2_0_pretrained.tar)
- [ShuffleNetV2_swish](https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_swish_pretrained.tar)
+ - GhostNet series[[23](#ref23)]([paper link](https://arxiv.org/pdf/1911.11907.pdf))
+ - [GhostNet_x0_5](https://paddle-imagenet-models-name.bj.bcebos.com/GhostNet_x0_5_pretrained.pdparams)
+ - [GhostNet_x1_0](https://paddle-imagenet-models-name.bj.bcebos.com/GhostNet_x1_0_pretrained.pdparams)
+ - [GhostNet_x1_3](https://paddle-imagenet-models-name.bj.bcebos.com/GhostNet_x1_3_pretrained.pdparams)
- SEResNeXt and Res2Net series
@@ -254,3 +258,5 @@ python tools/infer/predict.py \
[21] Redmon J, Divvala S, Girshick R, et al. You only look once: Unified, real-time object detection[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 779-788.
[22] Ding X, Guo Y, Ding G, et al. Acnet: Strengthening the kernel skeletons for powerful cnn via asymmetric convolution blocks[C]//Proceedings of the IEEE International Conference on Computer Vision. 2019: 1911-1920.
+
+[23] Han K, Wang Y, Tian Q, et al. GhostNet: More features from cheap operations[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020: 1580-1589.
\ No newline at end of file
diff --git a/docs/zh_CN/extension/paddle_inference.md b/docs/zh_CN/extension/paddle_inference.md
index 4490c8a1b0d0157466752d4a3dec043e6c5d41cd..17300278a025c7d1ef28350d6e25c5d7d4acd609 100644
--- a/docs/zh_CN/extension/paddle_inference.md
+++ b/docs/zh_CN/extension/paddle_inference.md
@@ -100,7 +100,7 @@ python tools/export_model.py \
在模型库的 `tools/infer/predict.py` 中提供了完整的示例,只需执行下述命令即可完成预测:
```
-python ./predict.py \
+python ./tools/infer/predict.py \
-i=./test.jpeg \
-m=./resnet50-vd/model \
-p=./resnet50-vd/params \
@@ -122,7 +122,7 @@ python ./predict.py \
注意:
当启用benchmark时,默认开启tersorrt进行预测
-
+
构建预测引擎:
@@ -259,4 +259,3 @@ outputs = exe.run(compiled_program,
```
上述执行预测时候的参数说明可以参考官网 [fluid.Executor](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/executor_cn/Executor_cn.html)
-
diff --git a/docs/zh_CN/models/Mobile.md b/docs/zh_CN/models/Mobile.md
index 3c0ebe37693b5564fbc654d99948d69b04b97a85..964afe5d69c2afc3bc35460797250f276dfc995f 100644
--- a/docs/zh_CN/models/Mobile.md
+++ b/docs/zh_CN/models/Mobile.md
@@ -9,6 +9,8 @@ ShuffleNet系列网络是旷视提出的轻量化网络结构,到目前为止
MobileNetV3是Google于2019年提出的一种基于NAS的新的轻量级网络,为了进一步提升效果,将relu和sigmoid激活函数分别替换为hard_swish与hard_sigmoid激活函数,同时引入了一些专门减小网络计算量的改进策略。
+GhosttNet是华为于2020年提出的一种全新的轻量化网络结构,通过引入ghost module,大大减缓了传统深度网络中特征的冗余计算问题,使得网络的参数量和计算量大大降低。
+
![](../../images/models/mobile_arm_top1.png)
![](../../images/models/mobile_arm_storage.png)
@@ -18,7 +20,7 @@ MobileNetV3是Google于2019年提出的一种基于NAS的新的轻量级网络
![](../../images/models/T4_benchmark/t4.fp32.bs4.mobile_trt.params.png)
-目前PaddleClas开源的的移动端系列的预训练模型一共有32个,其指标如图所示。从图片可以看出,越新的轻量级模型往往有更优的表现,MobileNetV3代表了目前最新的轻量级神经网络结构。在MobileNetV3中,作者为了获得更高的精度,在global-avg-pooling后使用了1x1的卷积。该操作大幅提升了参数量但对计算量影响不大,所以如果从存储角度评价模型的优异程度,MobileNetV3优势不是很大,但由于其更小的计算量,使得其有更快的推理速度。此外,我们模型库中的ssld蒸馏模型表现优异,从各个考量角度下,都刷新了当前轻量级模型的精度。由于MobileNetV3模型结构复杂,分支较多,对GPU并不友好,GPU预测速度不如MobileNetV1。
+目前PaddleClas开源的的移动端系列的预训练模型一共有35个,其指标如图所示。从图片可以看出,越新的轻量级模型往往有更优的表现,MobileNetV3代表了目前主流的轻量级神经网络结构。在MobileNetV3中,作者为了获得更高的精度,在global-avg-pooling后使用了1x1的卷积。该操作大幅提升了参数量但对计算量影响不大,所以如果从存储角度评价模型的优异程度,MobileNetV3优势不是很大,但由于其更小的计算量,使得其有更快的推理速度。此外,我们模型库中的ssld蒸馏模型表现优异,从各个考量角度下,都刷新了当前轻量级模型的精度。由于MobileNetV3模型结构复杂,分支较多,对GPU并不友好,GPU预测速度不如MobileNetV1。GhostNet于2020年提出,通过引入ghost的网络设计理念,大大降低了计算量和参数量,同时在精度上也超过前期最高的MobileNetV3网络结构。
## 精度、FLOPS和参数量
@@ -57,6 +59,9 @@ MobileNetV3是Google于2019年提出的一种基于NAS的新的轻量级网络
| ShuffleNetV2_x1_5 | 0.716 | 0.902 | 0.726 | | 0.580 | 3.470 |
| ShuffleNetV2_x2_0 | 0.732 | 0.912 | 0.749 | | 1.120 | 7.320 |
| ShuffleNetV2_swish | 0.700 | 0.892 | | | 0.290 | 2.260 |
+| GhostNet_x0_5 | 0.668 | 0.869 | 0.662 | 0.866 | 0.041 | 2.600 |
+| GhostNet_x1_0 | 0.740 | 0.916 | 0.739 | 0.914 | 0.147 | 5.200 |
+| GhostNet_x1_3 | 0.757 | 0.925 | 0.757 | 0.927 | 0.220 | 7.300 |
## 基于SD855的预测速度和存储大小
diff --git a/docs/zh_CN/models/models_intro.md b/docs/zh_CN/models/models_intro.md
index 3ed9a4b1aa804d06d4bf65ae0a81cb2e365915c2..c2bd273fba93d2ded171e2402f9be1b7a8135959 100644
--- a/docs/zh_CN/models/models_intro.md
+++ b/docs/zh_CN/models/models_intro.md
@@ -93,6 +93,10 @@ python tools/infer/predict.py \
- [ShuffleNetV2_x1_5](https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_x1_5_pretrained.tar)
- [ShuffleNetV2_x2_0](https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_x2_0_pretrained.tar)
- [ShuffleNetV2_swish](https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_swish_pretrained.tar)
+ - GhostNet系列[[23](#ref23)]([论文地址](https://arxiv.org/pdf/1911.11907.pdf))
+ - [GhostNet_x0_5](https://paddle-imagenet-models-name.bj.bcebos.com/GhostNet_x0_5_pretrained.pdparams)
+ - [GhostNet_x1_0](https://paddle-imagenet-models-name.bj.bcebos.com/GhostNet_x1_0_pretrained.pdparams)
+ - [GhostNet_x1_3](https://paddle-imagenet-models-name.bj.bcebos.com/GhostNet_x1_3_pretrained.pdparams)
- SEResNeXt与Res2Net系列
@@ -254,3 +258,5 @@ python tools/infer/predict.py \
[21] Redmon J, Divvala S, Girshick R, et al. You only look once: Unified, real-time object detection[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 779-788.
[22] Ding X, Guo Y, Ding G, et al. Acnet: Strengthening the kernel skeletons for powerful cnn via asymmetric convolution blocks[C]//Proceedings of the IEEE International Conference on Computer Vision. 2019: 1911-1920.
+
+[23] Han K, Wang Y, Tian Q, et al. GhostNet: More features from cheap operations[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020: 1580-1589.
diff --git a/ppcls/modeling/architectures/__init__.py b/ppcls/modeling/architectures/__init__.py
index bc7d7593f2eca994fc1d0133697a9f428b27a29a..8b288ed099acd9a849366ad138294b64876a5700 100644
--- a/ppcls/modeling/architectures/__init__.py
+++ b/ppcls/modeling/architectures/__init__.py
@@ -18,6 +18,7 @@ from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75
from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
from .googlenet import GoogLeNet
from .vgg import VGG11, VGG13, VGG16, VGG19
+from .regnet import RegNetX_200MF, RegNetX_4GF, RegNetX_32GF, RegNetY_200MF, RegNetY_4GF, RegNetY_32GF
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .resnet_vc import ResNet50_vc, ResNet101_vc, ResNet152_vc
from .resnet_vd import ResNet18_vd, ResNet34_vd, ResNet50_vd, ResNet101_vd, ResNet152_vd, ResNet200_vd
@@ -37,6 +38,9 @@ from .squeezenet import SqueezeNet1_0, SqueezeNet1_1
from .darknet import DarkNet53
from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl
from .efficientnet import EfficientNet, EfficientNetB0, EfficientNetB0_small, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
+
+from .efficientnetlite import EfficientNetLite, EfficientNetLite0, EfficientNetLite1, EfficientNetLite2, EfficientNetLite4
+
from .res2net import Res2Net50_48w_2s, Res2Net50_26w_4s, Res2Net50_14w_8s, Res2Net50_26w_6s, Res2Net50_26w_8s, Res2Net101_26w_4s, Res2Net152_26w_4s
from .res2net_vd import Res2Net50_vd_48w_2s, Res2Net50_vd_26w_4s, Res2Net50_vd_14w_8s, Res2Net50_vd_26w_6s, Res2Net50_vd_26w_8s, Res2Net101_vd_26w_4s, Res2Net152_vd_26w_4s, Res2Net200_vd_26w_4s
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C
diff --git a/ppcls/modeling/architectures/efficientnetlite.py b/ppcls/modeling/architectures/efficientnetlite.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b66aef3cd4ea34c595032e134242e1eb89ba9ac
--- /dev/null
+++ b/ppcls/modeling/architectures/efficientnetlite.py
@@ -0,0 +1,627 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import re
+import math
+import copy
+
+import paddle.fluid as fluid
+
+from .layers import conv2d, init_batch_norm_layer, init_fc_layer
+
+__all__ = [
+ 'EfficientNetLite', 'EfficientNetLite0', 'EfficientNetLite1',
+ 'EfficientNetLite2', 'EfficientNetLite3', 'EfficientNetLite4'
+]
+
+GlobalParams = collections.namedtuple('GlobalParams', [
+ 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'num_classes',
+ 'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth',
+ 'drop_connect_rate', 'fix_head_stem', 'relu_fn', 'local_pooling'
+])
+
+BlockArgs = collections.namedtuple('BlockArgs', [
+ 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
+ 'expand_ratio', 'id_skip', 'stride', 'se_ratio'
+])
+
+GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
+BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)
+
+
+def efficientnet_lite_params(model_name):
+ """ Map EfficientNet model name to parameter coefficients. """
+ params_dict = {
+ # Coefficients: width,depth,resolution,dropout
+ 'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
+ 'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
+ 'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
+ 'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
+ 'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
+ }
+ return params_dict[model_name]
+
+
+def efficientnet_lite(width_coefficient=None,
+ depth_coefficient=None,
+ dropout_rate=0.2,
+ drop_connect_rate=0.2):
+ """ Get block arguments according to parameter and coefficients. """
+ blocks_args = [
+ 'r1_k3_s11_e1_i32_o16_se0.25',
+ 'r2_k3_s22_e6_i16_o24_se0.25',
+ 'r2_k5_s22_e6_i24_o40_se0.25',
+ 'r3_k3_s22_e6_i40_o80_se0.25',
+ 'r3_k5_s11_e6_i80_o112_se0.25',
+ 'r4_k5_s22_e6_i112_o192_se0.25',
+ 'r1_k3_s11_e6_i192_o320_se0.25',
+ ]
+ blocks_args = BlockDecoder.decode(blocks_args)
+
+ global_params = GlobalParams(
+ batch_norm_momentum=0.99,
+ batch_norm_epsilon=1e-3,
+ dropout_rate=dropout_rate,
+ drop_connect_rate=drop_connect_rate,
+ num_classes=1000,
+ width_coefficient=width_coefficient,
+ depth_coefficient=depth_coefficient,
+ depth_divisor=8,
+ min_depth=None,
+ # FOR LITE, use relu6 for easier quantization
+ relu_fn=True,
+ # FOR LITE, Don't scale in Lite model
+ fix_head_stem=True,
+ # FOR LITE,
+ local_pooling=True)
+
+ return blocks_args, global_params
+
+
+def get_model_params(model_name, override_params):
+ """ Get the block args and global params for a given model """
+ if model_name.startswith('efficientnet-lite'):
+ w, d, _, p = efficientnet_lite_params(model_name)
+ blocks_args, global_params = efficientnet_lite(
+ width_coefficient=w, depth_coefficient=d, dropout_rate=p)
+ else:
+ raise NotImplementedError('model name is not pre-defined: %s' %
+ model_name)
+ if override_params:
+ global_params = global_params._replace(**override_params)
+ return blocks_args, global_params
+
+
+def round_filters(filters, global_params, skip=False):
+ """ Calculate and round number of filters based on depth multiplier. """
+ multiplier = global_params.width_coefficient
+ if skip or not multiplier:
+ return filters
+ divisor = global_params.depth_divisor
+ min_depth = global_params.min_depth
+ filters *= multiplier
+ min_depth = min_depth or divisor
+ new_filters = max(min_depth,
+ int(filters + divisor / 2) // divisor * divisor)
+ if new_filters < 0.9 * filters: # prevent rounding by more than 10%
+ new_filters += divisor
+ return int(new_filters)
+
+
+def round_repeats(repeats, global_params, skip=False):
+ """ Round number of filters based on depth multiplier. """
+ multiplier = global_params.depth_coefficient
+ if skip or not multiplier:
+ return repeats
+ return int(math.ceil(multiplier * repeats))
+
+
+class EfficientNetLite():
+ def __init__(
+ self,
+ name='lite0',
+ padding_type='SAME',
+ override_params=None,
+ is_test=False,
+ # For Lite, Don't use SE
+ use_se=False):
+ valid_names = ['lite' + str(i) for i in range(5)]
+ assert name in valid_names, 'efficientlite name should be in b0~b7'
+ model_name = 'efficientnet-' + name
+ self._blocks_args, self._global_params = get_model_params(
+ model_name, override_params)
+ print("global_params", self._global_params)
+ self._bn_mom = self._global_params.batch_norm_momentum
+ self._bn_eps = self._global_params.batch_norm_epsilon
+ self.is_test = is_test
+ self.padding_type = padding_type
+ self.use_se = use_se
+ self._relu_fn = self._global_params.relu_fn
+ self._fix_head_stem = self._global_params.fix_head_stem
+ self.local_pooling = self._global_params.local_pooling
+ # NCHW spatial: HW
+ self._spatial_dims = [2, 3]
+
+ def net(self, input, class_dim=1000, is_test=False):
+
+ conv = self.extract_features(input, is_test=is_test)
+
+ out_channels = round_filters(1280, self._global_params,
+ self._fix_head_stem)
+ conv = self.conv_bn_layer(
+ conv,
+ num_filters=out_channels,
+ filter_size=1,
+ bn_act='relu6' if self._relu_fn else 'swish', # for lite
+ bn_mom=self._bn_mom,
+ bn_eps=self._bn_eps,
+ padding_type=self.padding_type,
+ name='',
+ conv_name='_conv_head',
+ bn_name='_bn1')
+
+ pool = fluid.layers.pool2d(
+ input=conv, pool_type='avg', global_pooling=True, use_cudnn=False)
+
+ if self._global_params.dropout_rate:
+ pool = fluid.layers.dropout(
+ pool,
+ self._global_params.dropout_rate,
+ dropout_implementation='upscale_in_train')
+
+ param_attr, bias_attr = init_fc_layer(class_dim, '_fc')
+ out = fluid.layers.fc(pool,
+ class_dim,
+ name='_fc',
+ param_attr=param_attr,
+ bias_attr=bias_attr)
+ return out
+
+ def _drop_connect(self, inputs, prob, is_test):
+ if is_test:
+ return inputs
+ keep_prob = 1.0 - prob
+ inputs_shape = fluid.layers.shape(inputs)
+ random_tensor = keep_prob + fluid.layers.uniform_random(
+ shape=[inputs_shape[0], 1, 1, 1], min=0., max=1.)
+ binary_tensor = fluid.layers.floor(random_tensor)
+ output = inputs / keep_prob * binary_tensor
+ return output
+
+ def _expand_conv_norm(self, inputs, block_args, is_test, name=None):
+ # Expansion phase
+ oup = block_args.input_filters * \
+ block_args.expand_ratio # number of output channels
+
+ if block_args.expand_ratio != 1:
+ conv = self.conv_bn_layer(
+ inputs,
+ num_filters=oup,
+ filter_size=1,
+ bn_act=None,
+ bn_mom=self._bn_mom,
+ bn_eps=self._bn_eps,
+ padding_type=self.padding_type,
+ name=name,
+ conv_name=name + '_expand_conv',
+ bn_name='_bn0')
+
+ return conv
+
+ def _depthwise_conv_norm(self, inputs, block_args, is_test, name=None):
+ k = block_args.kernel_size
+ s = block_args.stride
+ if isinstance(s, list) or isinstance(s, tuple):
+ s = s[0]
+ oup = block_args.input_filters * \
+ block_args.expand_ratio # number of output channels
+
+ conv = self.conv_bn_layer(
+ inputs,
+ num_filters=oup,
+ filter_size=k,
+ stride=s,
+ num_groups=oup,
+ bn_act=None,
+ padding_type=self.padding_type,
+ bn_mom=self._bn_mom,
+ bn_eps=self._bn_eps,
+ name=name,
+ use_cudnn=False,
+ conv_name=name + '_depthwise_conv',
+ bn_name='_bn1')
+
+ return conv
+
+ def _project_conv_norm(self, inputs, block_args, is_test, name=None):
+ final_oup = block_args.output_filters
+ conv = self.conv_bn_layer(
+ inputs,
+ num_filters=final_oup,
+ filter_size=1,
+ bn_act=None,
+ padding_type=self.padding_type,
+ bn_mom=self._bn_mom,
+ bn_eps=self._bn_eps,
+ name=name,
+ conv_name=name + '_project_conv',
+ bn_name='_bn2')
+ return conv
+
+ def conv_bn_layer(
+ self,
+ input,
+ filter_size,
+ num_filters,
+ stride=1,
+ num_groups=1,
+ padding_type="SAME",
+ conv_act=None,
+ bn_act='relu6', # if self._relu_fn else 'swish',
+ use_cudnn=True,
+ use_bn=True,
+ bn_mom=0.9,
+ bn_eps=1e-05,
+ use_bias=False,
+ name=None,
+ conv_name=None,
+ bn_name=None):
+ conv = conv2d(
+ input=input,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ groups=num_groups,
+ act=conv_act,
+ padding_type=padding_type,
+ use_cudnn=use_cudnn,
+ name=conv_name,
+ use_bias=use_bias)
+
+ if use_bn is False:
+ return conv
+ else:
+ bn_name = name + bn_name
+ param_attr, bias_attr = init_batch_norm_layer(bn_name)
+ return fluid.layers.batch_norm(
+ input=conv,
+ act=bn_act,
+ momentum=bn_mom,
+ epsilon=bn_eps,
+ name=bn_name,
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance',
+ param_attr=param_attr,
+ bias_attr=bias_attr)
+
+ def _conv_stem_norm(self, inputs, is_test):
+ out_channels = round_filters(32, self._global_params,
+ self._fix_head_stem)
+ bn = self.conv_bn_layer(
+ inputs,
+ num_filters=out_channels,
+ filter_size=3,
+ stride=2,
+ bn_act=None,
+ bn_mom=self._bn_mom,
+ padding_type=self.padding_type,
+ bn_eps=self._bn_eps,
+ name='',
+ conv_name='_conv_stem',
+ bn_name='_bn0')
+
+ return bn
+
+ def mb_conv_block(self,
+ inputs,
+ block_args,
+ is_test=False,
+ drop_connect_rate=None,
+ name=None):
+ # Expansion and Depthwise Convolution
+ oup = block_args.input_filters * \
+ block_args.expand_ratio # number of output channels
+ has_se = self.use_se and (block_args.se_ratio is not None) and (
+ 0 < block_args.se_ratio <= 1)
+ id_skip = block_args.id_skip # skip connection and drop connect
+ conv = inputs
+ if block_args.expand_ratio != 1:
+ if self._relu_fn:
+ conv = fluid.layers.relu6(
+ self._expand_conv_norm(conv, block_args, is_test, name))
+ else:
+ conv = fluid.layers.swish(
+ self._expand_conv_norm(conv, block_args, is_test, name))
+
+ if self._relu_fn:
+ conv = fluid.layers.relu6(
+ self._depthwise_conv_norm(conv, block_args, is_test, name))
+ else:
+ conv = fluid.layers.swish(
+ self._depthwise_conv_norm(conv, block_args, is_test, name))
+
+ # Squeeze and Excitation
+ if has_se:
+ num_squeezed_channels = max(
+ 1, int(block_args.input_filters * block_args.se_ratio))
+ conv = self.se_block(conv, num_squeezed_channels, oup, name)
+
+ conv = self._project_conv_norm(conv, block_args, is_test, name)
+
+ # Skip connection and drop connect
+ input_filters = block_args.input_filters
+ output_filters = block_args.output_filters
+ if id_skip and \
+ block_args.stride == 1 and \
+ input_filters == output_filters:
+ if drop_connect_rate:
+ conv = self._drop_connect(conv, drop_connect_rate,
+ self.is_test)
+ conv = fluid.layers.elementwise_add(conv, inputs)
+
+ return conv
+
+ def se_block(self, inputs, num_squeezed_channels, oup, name):
+
+ if self.local_pooling:
+ shape = inputs.shape
+ x_squeezed = fluid.layers.pool2d(
+ input=inputs,
+ pool_size=[
+ shape[self._spatial_dims[0]], shape[self._spatial_dims[1]]
+ ],
+ pool_stride=[1, 1],
+ pool_padding='VALID')
+ else:
+ # same as tf: reduce_sum
+ x_squeezed = fluid.layers.pool2d(
+ input=inputs,
+ pool_type='avg',
+ global_pooling=True,
+ use_cudnn=False)
+ x_squeezed = conv2d(
+ x_squeezed,
+ num_filters=num_squeezed_channels,
+ filter_size=1,
+ use_bias=True,
+ padding_type=self.padding_type,
+ act='relu6' if self._relu_fn else 'swish',
+ name=name + '_se_reduce')
+ x_squeezed = conv2d(
+ x_squeezed,
+ num_filters=oup,
+ filter_size=1,
+ use_bias=True,
+ padding_type=self.padding_type,
+ name=name + '_se_expand')
+ #se_out = inputs * fluid.layers.sigmoid(x_squeezed)
+ se_out = fluid.layers.elementwise_mul(
+ inputs, fluid.layers.sigmoid(x_squeezed), axis=-1)
+ return se_out
+
+ def extract_features(self, inputs, is_test):
+ """ Returns output of the final convolution layer """
+
+ if self._relu_fn:
+ conv = fluid.layers.relu6(
+ self._conv_stem_norm(
+ inputs, is_test=is_test))
+ else:
+ fluid.layers.swish(self._conv_stem_norm(inputs, is_test=is_test))
+
+ block_args_copy = copy.deepcopy(self._blocks_args)
+ idx = 0
+ block_size = 0
+ for i, block_arg in enumerate(block_args_copy):
+ block_arg = block_arg._replace(
+ input_filters=round_filters(block_arg.input_filters,
+ self._global_params),
+ output_filters=round_filters(block_arg.output_filters,
+ self._global_params),
+ # Lite
+ num_repeat=block_arg.num_repeat if self._fix_head_stem and
+ (i == 0 or i == len(block_args_copy) - 1) else round_repeats(
+ block_arg.num_repeat, self._global_params))
+
+ block_size += 1
+ for _ in range(block_arg.num_repeat - 1):
+ block_size += 1
+
+ for i, block_args in enumerate(self._blocks_args):
+
+ # Update block input and output filters based on depth multiplier.
+ block_args = block_args._replace(
+ input_filters=round_filters(block_args.input_filters,
+ self._global_params),
+ output_filters=round_filters(block_args.output_filters,
+ self._global_params),
+
+ # Lite
+ num_repeat=block_args.num_repeat if self._fix_head_stem and
+ (i == 0 or i == len(self._blocks_args) - 1) else
+ round_repeats(block_args.num_repeat, self._global_params))
+
+ # The first block needs to take care of stride,
+ # and filter size increase.
+ drop_connect_rate = self._global_params.drop_connect_rate
+ if drop_connect_rate:
+ drop_connect_rate *= float(idx) / block_size
+ conv = self.mb_conv_block(conv, block_args, is_test,
+ drop_connect_rate,
+ '_blocks.' + str(idx) + '.')
+
+ idx += 1
+ if block_args.num_repeat > 1:
+ block_args = block_args._replace(
+ input_filters=block_args.output_filters, stride=1)
+ for _ in range(block_args.num_repeat - 1):
+ drop_connect_rate = self._global_params.drop_connect_rate
+ if drop_connect_rate:
+ drop_connect_rate *= float(idx) / block_size
+ conv = self.mb_conv_block(conv, block_args, is_test,
+ drop_connect_rate,
+ '_blocks.' + str(idx) + '.')
+ idx += 1
+
+ return conv
+
+ def shortcut(self, input, data_residual):
+ return fluid.layers.elementwise_add(input, data_residual)
+
+
+class BlockDecoder(object):
+ """
+ Block Decoder, straight from the official TensorFlow repository.
+ """
+
+ @staticmethod
+ def _decode_block_string(block_string):
+ """ Gets a block through a string notation of arguments. """
+ assert isinstance(block_string, str)
+
+ ops = block_string.split('_')
+ options = {}
+ for op in ops:
+ splits = re.split(r'(\d.*)', op)
+ if len(splits) >= 2:
+ key, value = splits[:2]
+ options[key] = value
+
+ # Check stride
+ cond_1 = ('s' in options and len(options['s']) == 1)
+ cond_2 = ((len(options['s']) == 2) and
+ (options['s'][0] == options['s'][1]))
+ assert (cond_1 or cond_2)
+
+ return BlockArgs(
+ kernel_size=int(options['k']),
+ num_repeat=int(options['r']),
+ input_filters=int(options['i']),
+ output_filters=int(options['o']),
+ expand_ratio=int(options['e']),
+ id_skip=('noskip' not in block_string),
+ se_ratio=float(options['se']) if 'se' in options else None,
+ stride=[int(options['s'][0])])
+
+ @staticmethod
+ def _encode_block_string(block):
+ """Encodes a block to a string."""
+ args = [
+ 'r%d' % block.num_repeat, 'k%d' % block.kernel_size, 's%d%d' %
+ (block.strides[0], block.strides[1]), 'e%s' % block.expand_ratio,
+ 'i%d' % block.input_filters, 'o%d' % block.output_filters
+ ]
+ if 0 < block.se_ratio <= 1:
+ args.append('se%s' % block.se_ratio)
+ if block.id_skip is False:
+ args.append('noskip')
+ return '_'.join(args)
+
+ @staticmethod
+ def decode(string_list):
+ """
+ Decode a list of string notations to specify blocks in the network.
+
+ string_list: list of strings, each string is a notation of block
+ return
+ list of BlockArgs namedtuples of block args
+ """
+ assert isinstance(string_list, list)
+ blocks_args = []
+ for block_string in string_list:
+ blocks_args.append(BlockDecoder._decode_block_string(block_string))
+ return blocks_args
+
+ @staticmethod
+ def encode(blocks_args):
+ """
+ Encodes a list of BlockArgs to a list of strings.
+
+ :param blocks_args: a list of BlockArgs namedtuples of block args
+ :return: a list of strings, each string is a notation of block
+ """
+ block_strings = []
+ for block in blocks_args:
+ block_strings.append(BlockDecoder._encode_block_string(block))
+ return block_strings
+
+
+def EfficientNetLite0(is_test=False,
+ padding_type='SAME',
+ override_params=None,
+ use_se=True):
+ model = EfficientNetLite(
+ name='lite0',
+ is_test=is_test,
+ padding_type=padding_type,
+ override_params=override_params,
+ use_se=use_se)
+ return model
+
+
+def EfficientNetLite1(is_test=False,
+ padding_type='SAME',
+ override_params=None,
+ use_se=True):
+ model = EfficientNetLite(
+ name='lite1',
+ is_test=is_test,
+ padding_type=padding_type,
+ override_params=override_params,
+ use_se=use_se)
+ return model
+
+
+def EfficientNetLite2(is_test=False,
+ padding_type='SAME',
+ override_params=None,
+ use_se=True):
+ model = EfficientNetLite(
+ name='lite2',
+ is_test=is_test,
+ padding_type=padding_type,
+ override_params=override_params,
+ use_se=use_se)
+ return model
+
+
+def EfficientNetLite3(is_test=False,
+ padding_type='SAME',
+ override_params=None,
+ use_se=True):
+ model = EfficientNetLite(
+ name='lite3',
+ is_test=is_test,
+ padding_type=padding_type,
+ override_params=override_params,
+ use_se=use_se)
+ return model
+
+
+def EfficientNetLite4(is_test=False,
+ padding_type='SAME',
+ override_params=None,
+ use_se=True):
+ model = EfficientNetLite(
+ name='lite4',
+ is_test=is_test,
+ padding_type=padding_type,
+ override_params=override_params,
+ use_se=use_se)
+ return model
diff --git a/ppcls/modeling/architectures/ghostnet.py b/ppcls/modeling/architectures/ghostnet.py
index 038e2f39fd5f77dc93d47790c9edb0447191a6ad..a1c22eaaee03f465c1c0a59ee7c5b304a6cc2728 100644
--- a/ppcls/modeling/architectures/ghostnet.py
+++ b/ppcls/modeling/architectures/ghostnet.py
@@ -1,3 +1,17 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -152,7 +166,7 @@ class GhostNet():
initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_2_weights'),
bias_attr=ParamAttr(name=name + '_2_offset'))
- #excitation = fluid.layers.clip(x=excitation, min=0, max=1)
+ excitation = fluid.layers.clip(x=excitation, min=0, max=1)
se_scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return se_scale
diff --git a/ppcls/modeling/architectures/layers.py b/ppcls/modeling/architectures/layers.py
index f99103b0516cdeed1d36ecd151bc63cc62a1b182..0546514fe48410f620e4da430bd9af8dc49eb6a6 100644
--- a/ppcls/modeling/architectures/layers.py
+++ b/ppcls/modeling/architectures/layers.py
@@ -1,16 +1,16 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from __future__ import absolute_import
from __future__ import division
@@ -242,6 +242,8 @@ def conv2d(input,
conv = fluid.layers.sigmoid(conv, name=name + '_sigmoid')
elif act == 'swish':
conv = fluid.layers.swish(conv, name=name + '_swish')
+ elif act == 'relu6':
+ conv = fluid.layers.relu6(conv, name=name + '_relu6')
elif act == None:
conv = conv
else:
diff --git a/ppcls/modeling/architectures/regnet.py b/ppcls/modeling/architectures/regnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..8689e0039adea2378e3190f9da136d3cd616f947
--- /dev/null
+++ b/ppcls/modeling/architectures/regnet.py
@@ -0,0 +1,356 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+
+__all__ = [
+ "RegNetX_200MF", "RegNetX_4GF", "RegNetX_32GF", "RegNetY_200MF",
+ "RegNetY_4GF", "RegNetY_32GF"
+]
+
+
+class RegNet():
+ def __init__(self, w_a, w_0, w_m, d, group_w, bot_mul, q=8, se_on=False):
+ self.w_a = w_a
+ self.w_0 = w_0
+ self.w_m = w_m
+ self.d = d
+ self.q = q
+ self.group_w = group_w
+ self.bot_mul = bot_mul
+ # Stem type
+ self.stem_type = "simple_stem_in"
+ # Stem width
+ self.stem_w = 32
+ # Block type
+ self.block_type = "res_bottleneck_block"
+ # Stride of each stage
+ self.stride = 2
+ # Squeeze-and-Excitation (RegNetY)
+ self.se_on = se_on
+ self.se_r = 0.25
+
+ def quantize_float(self, f, q):
+ """Converts a float to closest non-zero int divisible by q."""
+ return int(round(f / q) * q)
+
+ def adjust_ws_gs_comp(self, ws, bms, gs):
+ """Adjusts the compatibility of widths and groups."""
+ ws_bot = [int(w * b) for w, b in zip(ws, bms)]
+ gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
+ ws_bot = [
+ self.quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)
+ ]
+ ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
+ return ws, gs
+
+ def get_stages_from_blocks(self, ws, rs):
+ """Gets ws/ds of network at each stage from per block values."""
+ ts = [
+ w != wp or r != rp
+ for w, wp, r, rp in zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
+ ]
+ s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
+ s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
+ return s_ws, s_ds
+
+ def generate_regnet(self, w_a, w_0, w_m, d, q=8):
+ """Generates per block ws from RegNet parameters."""
+ assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
+ ws_cont = np.arange(d) * w_a + w_0
+ ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
+ ws = w_0 * np.power(w_m, ks)
+ ws = np.round(np.divide(ws, q)) * q
+ num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
+ ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
+ return ws, num_stages, max_stage, ws_cont
+
+ def init_weights(self, op_type, filter_size=0, num_channels=0, name=None):
+ if op_type == 'conv':
+ fan_out = num_channels * filter_size * filter_size
+ param_attr = ParamAttr(
+ name=name + "_weights",
+ initializer=fluid.initializer.NormalInitializer(
+ loc=0.0, scale=math.sqrt(2.0 / fan_out)))
+ bias_attr = False
+ elif op_type == 'bn':
+ param_attr = ParamAttr(
+ name=name + "_scale",
+ initializer=fluid.initializer.Constant(0.0))
+ bias_attr = ParamAttr(
+ name=name + "_offset",
+ initializer=fluid.initializer.Constant(0.0))
+ elif op_type == 'final_bn':
+ param_attr = ParamAttr(
+ name=name + "_scale",
+ initializer=fluid.initializer.Constant(1.0))
+ bias_attr = ParamAttr(
+ name=name + "_offset",
+ initializer=fluid.initializer.Constant(0.0))
+ return param_attr, bias_attr
+
+ def net(self, input, class_dim=1000):
+ # Generate RegNet ws per block
+ b_ws, num_s, max_s, ws_cont = self.generate_regnet(
+ self.w_a, self.w_0, self.w_m, self.d, self.q)
+ # Convert to per stage format
+ ws, ds = self.get_stages_from_blocks(b_ws, b_ws)
+ # Generate group widths and bot muls
+ gws = [self.group_w for _ in range(num_s)]
+ bms = [self.bot_mul for _ in range(num_s)]
+ # Adjust the compatibility of ws and gws
+ ws, gws = self.adjust_ws_gs_comp(ws, bms, gws)
+ # Use the same stride for each stage
+ ss = [self.stride for _ in range(num_s)]
+ # Use SE for RegNetY
+ se_r = self.se_r
+
+ # Construct the model
+ # Group params by stage
+ stage_params = list(zip(ds, ws, ss, bms, gws))
+ # Construct the stem
+ conv = self.conv_bn_layer(
+ input=input,
+ num_filters=self.stem_w,
+ filter_size=3,
+ stride=2,
+ padding=1,
+ act='relu',
+ name="stem_conv")
+ # Construct the stages
+ for block, (d, w_out, stride, bm, gw) in enumerate(stage_params):
+ for i in range(d):
+ # Stride apply to the first block of the stage
+ b_stride = stride if i == 0 else 1
+ conv_name = 's' + str(block + 1) + '_b' + str(i +
+ 1) # chr(97 + i)
+ conv = self.bottleneck_block(
+ input=conv,
+ num_filters=w_out,
+ stride=b_stride,
+ bm=bm,
+ gw=gw,
+ se_r=self.se_r,
+ name=conv_name)
+ pool = fluid.layers.pool2d(
+ input=conv, pool_type='avg', global_pooling=True)
+ out = fluid.layers.fc(
+ input=pool,
+ size=class_dim,
+ param_attr=ParamAttr(
+ name="fc_0.w_0",
+ initializer=fluid.initializer.NormalInitializer(
+ loc=0.0, scale=0.01)),
+ bias_attr=ParamAttr(
+ name="fc_0.b_0", initializer=fluid.initializer.Constant(0.0)))
+ return out
+
+ def conv_bn_layer(self,
+ input,
+ num_filters,
+ filter_size,
+ stride=1,
+ groups=1,
+ padding=0,
+ act=None,
+ name=None,
+ final_bn=False):
+ param_attr, bias_attr = self.init_weights(
+ op_type='conv',
+ filter_size=filter_size,
+ num_channels=num_filters,
+ name=name)
+ conv = fluid.layers.conv2d(
+ input=input,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ act=None,
+ param_attr=param_attr,
+ bias_attr=bias_attr,
+ name=name + '.conv2d.output.1')
+ bn_name = name + '_bn'
+ if final_bn:
+ param_attr, bias_attr = self.init_weights(
+ op_type='final_bn', name=bn_name)
+ else:
+ param_attr, bias_attr = self.init_weights(
+ op_type='bn', name=bn_name)
+ return fluid.layers.batch_norm(
+ input=conv,
+ act=act,
+ name=bn_name + '.output.1',
+ param_attr=param_attr,
+ bias_attr=bias_attr,
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance', )
+
+ def shortcut(self, input, ch_out, stride, name):
+ ch_in = input.shape[1]
+ if ch_in != ch_out or stride != 1:
+ return self.conv_bn_layer(
+ input=input,
+ num_filters=ch_out,
+ filter_size=1,
+ stride=stride,
+ padding=0,
+ act=None,
+ name=name)
+ else:
+ return input
+
+ def bottleneck_block(self, input, num_filters, stride, bm, gw, se_r, name):
+ # Compute the bottleneck width
+ w_b = int(round(num_filters * bm))
+ # Compute the number of groups
+ num_gs = w_b // gw
+ conv0 = self.conv_bn_layer(
+ input=input,
+ num_filters=w_b,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name=name + "_branch2a")
+ conv1 = self.conv_bn_layer(
+ input=conv0,
+ num_filters=w_b,
+ filter_size=3,
+ stride=stride,
+ padding=1,
+ groups=num_gs,
+ act='relu',
+ name=name + "_branch2b")
+ # Squeeze-and-Excitation (SE)
+ if self.se_on:
+ w_se = int(round(input.shape[1] * se_r))
+ conv1 = self.squeeze_excitation(
+ input=conv1,
+ num_channels=w_b,
+ reduction_channels=w_se,
+ name=name + "_branch2se")
+
+ conv2 = self.conv_bn_layer(
+ input=conv1,
+ num_filters=num_filters,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ act=None,
+ name=name + "_branch2c",
+ final_bn=True)
+
+ short = self.shortcut(
+ input, num_filters, stride, name=name + "_branch1")
+
+ return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
+
+ def squeeze_excitation(self,
+ input,
+ num_channels,
+ reduction_channels,
+ name=None):
+ pool = fluid.layers.pool2d(
+ input=input, pool_size=0, pool_type='avg', global_pooling=True)
+ fan_out = num_channels
+ squeeze = fluid.layers.conv2d(
+ input=pool,
+ num_filters=reduction_channels,
+ filter_size=1,
+ act='relu',
+ param_attr=ParamAttr(
+ initializer=fluid.initializer.NormalInitializer(
+ loc=0.0, scale=math.sqrt(2.0 / fan_out)),
+ name=name + '_sqz_weights'),
+ bias_attr=ParamAttr(name=name + '_sqz_offset'))
+ excitation = fluid.layers.conv2d(
+ input=squeeze,
+ num_filters=num_channels,
+ filter_size=1,
+ act='sigmoid',
+ param_attr=ParamAttr(
+ initializer=fluid.initializer.NormalInitializer(
+ loc=0.0, scale=math.sqrt(2.0 / fan_out)),
+ name=name + '_exc_weights'),
+ bias_attr=ParamAttr(name=name + '_exc_offset'))
+ scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
+ return scale
+
+
+def RegNetX_200MF():
+ model = RegNet(
+ w_a=36.44, w_0=24, w_m=2.49, d=13, group_w=8, bot_mul=1.0, q=8)
+ return model
+
+
+def RegNetX_4GF():
+ model = RegNet(
+ w_a=38.65, w_0=96, w_m=2.43, d=23, group_w=40, bot_mul=1.0, q=8)
+ return model
+
+
+def RegNetX_32GF():
+ model = RegNet(
+ w_a=69.86, w_0=320, w_m=2.0, d=23, group_w=168, bot_mul=1.0, q=8)
+ return model
+
+
+def RegNetY_200MF():
+ model = RegNet(
+ w_a=36.44,
+ w_0=24,
+ w_m=2.49,
+ d=13,
+ group_w=8,
+ bot_mul=1.0,
+ q=8,
+ se_on=True)
+ return model
+
+
+def RegNetY_4GF():
+ model = RegNet(
+ w_a=31.41,
+ w_0=96,
+ w_m=2.24,
+ d=22,
+ group_w=64,
+ bot_mul=1.0,
+ q=8,
+ se_on=True)
+ return model
+
+
+def RegNetY_32GF():
+ model = RegNet(
+ w_a=115.89,
+ w_0=232,
+ w_m=2.53,
+ d=20,
+ group_w=232,
+ bot_mul=1.0,
+ q=8,
+ se_on=True)
+ return model
diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py
index e1712a8a83768f83e6e44e3bdf7fdf8652ffed76..54ee51752a6658972a93a88380556963698c9cf4 100644
--- a/ppcls/utils/config.py
+++ b/ppcls/utils/config.py
@@ -144,9 +144,14 @@ def override(dl, ks, v):
override(dl[k], ks[1:], v)
else:
if len(ks) == 1:
- assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
+ #assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
+ if not ks[0] in dl:
+ logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
dl[ks[0]] = str2num(v)
else:
+ assert ks[0] in dl, (
+ '({}) doesn\'t exist in {}, a new dict field is invalid'.
+ format(ks[0], dl))
override(dl[ks[0]], ks[1:], v)
diff --git a/ppcls/utils/pretrained.list b/ppcls/utils/pretrained.list
index 36d70f5a24624dad507d51e3bc7c77eeb5444e9c..e282d616a8824fdb99dbd620cbb6a743d7efddc6 100644
--- a/ppcls/utils/pretrained.list
+++ b/ppcls/utils/pretrained.list
@@ -119,3 +119,6 @@ VGG19
DarkNet53_ImageNet1k
ResNet50_ACNet_deploy
CSPResNet50_leaky
+GhostNet_x0_5
+GhostNet_x1_0
+GhostNet_x1_3
diff --git a/tools/dali.py b/tools/dali.py
index b018ae79744c91766f7a104796346f22d9ce02ab..e6148f52a601280144185e626c35f1cdb043cee4 100644
--- a/tools/dali.py
+++ b/tools/dali.py
@@ -151,10 +151,15 @@ def build(settings, mode='train'):
file_root = settings.TRAIN.data_dir
bs = settings.TRAIN.batch_size if mode == 'train' else settings.VALID.batch_size
- print(bs, paddle.fluid.core.get_cuda_device_count())
- assert bs % paddle.fluid.core.get_cuda_device_count() == 0, \
+
+ gpu_num = paddle.fluid.core.get_cuda_device_count() if (
+ 'PADDLE_TRAINERS_NUM') and (
+ 'PADDLE_TRAINER_ID'
+ ) not in env else int(env.get('PADDLE_TRAINERS_NUM', 0))
+
+ assert bs % gpu_num == 0, \
"batch size must be multiple of number of devices"
- batch_size = bs // paddle.fluid.core.get_cuda_device_count()
+ batch_size = bs // gpu_num
image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225]
diff --git a/tools/program.py b/tools/program.py
index 3af811b0025aaff2b270018c23f108aba59861fe..21db1310721bee9b4e0eed80fc7674047afe75c4 100644
--- a/tools/program.py
+++ b/tools/program.py
@@ -456,16 +456,29 @@ def run(dataloader,
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
total_step += 1
if mode == 'eval':
- logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
+ if idx % config.get('print_interval', 10) == 0:
+ logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
+ fetchs_str))
else:
epoch_str = "epoch:{:<3d}".format(epoch)
step_str = "{:s} step:{:<4d}".format(mode, idx)
- logger.info("{:s} {:s} {:s}".format(
- logger.coloring(epoch_str, "HEADER")
- if idx == 0 else epoch_str,
- logger.coloring(step_str, "PURPLE"),
- logger.coloring(fetchs_str, 'OKGREEN')))
+ # Keep the first 10 batches statistics, They are important for develop
+ if epoch == 0 and idx < 10:
+ logger.info("{:s} {:s} {:s}".format(
+ logger.coloring(epoch_str, "HEADER")
+ if idx == 0 else epoch_str,
+ logger.coloring(step_str, "PURPLE"),
+ logger.coloring(fetchs_str, 'OKGREEN')))
+
+ else:
+ if idx % config.get('print_interval', 10) == 0:
+ logger.info("{:s} {:s} {:s}".format(
+ logger.coloring(epoch_str, "HEADER")
+ if idx == 0 else epoch_str,
+ logger.coloring(step_str, "PURPLE"),
+ logger.coloring(fetchs_str, 'OKGREEN')))
+
if config.get('use_dali'):
dataloader.reset()
diff --git a/tools/run.sh b/tools/run.sh
index 55f2918d91387a88227027656e9ac5e00320d358..5e8043b1205f99dada7448d172f84cd0661df509 100755
--- a/tools/run.sh
+++ b/tools/run.sh
@@ -5,4 +5,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH
python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \
tools/train.py \
- -c ./configs/ResNet/ResNet50.yaml
+ -c ./configs/ResNet/ResNet50.yaml \
+ -o print_interval=10