From f6f1c3adafcab7ee65055a0f245ff833102ea481 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Sat, 21 Sep 2019 00:16:12 +0800 Subject: [PATCH] Add efficientNet to image classification. (#3358) Add EfficientNet --- PaddleCV/image_classification/README.md | 34 ++ PaddleCV/image_classification/README_en.md | 39 +- PaddleCV/image_classification/ema_clean.py | 47 ++ PaddleCV/image_classification/eval.py | 7 +- PaddleCV/image_classification/infer.py | 3 +- .../image_classification/models/__init__.py | 1 + .../models/efficientnet.py | 453 ++++++++++++++++++ .../image_classification/models/layers.py | 222 +++++++++ PaddleCV/image_classification/reader.py | 28 +- .../scripts/train/EfficientNetB0.sh | 27 ++ PaddleCV/image_classification/train.py | 102 ++-- .../image_classification/utils/__init__.py | 4 +- .../image_classification/utils/autoaugment.py | 243 ++++++++++ .../image_classification/utils/optimizer.py | 54 ++- .../image_classification/utils/utility.py | 158 +++++- 15 files changed, 1363 insertions(+), 59 deletions(-) create mode 100644 PaddleCV/image_classification/ema_clean.py create mode 100644 PaddleCV/image_classification/models/efficientnet.py create mode 100644 PaddleCV/image_classification/models/layers.py create mode 100644 PaddleCV/image_classification/scripts/train/EfficientNetB0.sh create mode 100644 PaddleCV/image_classification/utils/autoaugment.py diff --git a/PaddleCV/image_classification/README.md b/PaddleCV/image_classification/README.md index f69dc88f..1ede86b1 100644 --- a/PaddleCV/image_classification/README.md +++ b/PaddleCV/image_classification/README.md @@ -115,6 +115,8 @@ bash run.sh train 模型名 * **l2_decay**: l2_decay值,默认值: 1e-4 * **momentum_rate**: momentum_rate值,默认值: 0.9 * **step_epochs**: piecewise dacay的decay step,默认值:[30,60,90] +* **decay_epochs**: exponential decay的间隔epoch数, 默认值: 2.4. +* **decay_rate**: exponential decay的下降率, 默认值: 0.97. 数据读取器和预处理配置: @@ -125,6 +127,7 @@ bash run.sh train 模型名 * **crop_size**: 指定裁剪的大小,默认值:224 * **use_mixup**: 是否对数据进行mixup处理,默认值: False * **mixup_alpha**: 指定mixup处理时的alpha值,默认值: 0.2 +* **use_aa**: 是否对数据进行auto augment处理. 默认值: False. * **reader_thread**: 多线程reader的线程数量,默认值: 8 * **reader_buf_size**: 多线程reader的buf_size, 默认值: 2048 * **interpolation**: 插值方法, 默认值:None @@ -138,6 +141,9 @@ bash run.sh train 模型名 * **use_label_smoothing**: 是否对数据进行label smoothing处理,默认值: False * **label_smoothing_epsilon**: label_smoothing的epsilon, 默认值:0.1 * **random_seed**: 随机数种子, 默认值: 1000 +* **padding_type**: efficientNet中卷积操作的padding方式, 默认值: "SAME". +* **use_ema**: 是否在更新模型参数时使用ExponentialMovingAverage. 默认值: False. +* **ema_decay**: ExponentialMovingAverage的decay rate. 默认值: 0.9999. 
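The `decay_epochs` / `decay_rate` pair only takes effect with the `exponential_decay_warmup` learning-rate strategy that this patch adds in `utils/optimizer.py`. A minimal sketch of that schedule is shown below; it is illustrative only: `step_each_epoch` stands in for `ceil(total_images / batch_size)` and the default values simply mirror the EfficientNetB0 training script.

```
import math

def lr_at_step(global_step, base_lr=0.032, step_each_epoch=2503,
               decay_epochs=2.4, decay_rate=0.97, warmup_epochs=5.0):
    """Sketch of exponential_decay_with_warmup: linear warmup for
    warmup_epochs epochs, then base_lr * decay_rate ** floor(t / decay_steps)."""
    warmup_steps = step_each_epoch * warmup_epochs
    if global_step < warmup_steps:
        # linear warmup from 0 up to base_lr
        return base_lr * global_step / warmup_steps
    decay_steps = step_each_epoch * decay_epochs
    return base_lr * decay_rate ** math.floor((global_step - warmup_steps) / decay_steps)
```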
**数据读取器说明:** 数据读取器定义在```reader.py```文件中,现在默认基于cv2的数据读取器, 在[训练阶段](#模型训练),默认采用的增广方式是随机裁剪与水平翻转, 而在[模型评估](#模型评估)与[模型预测](#模型预测)阶段用的默认方式是中心裁剪。当前支持的数据增广方式有: @@ -147,6 +153,7 @@ bash run.sh train 模型名 * 中心裁剪 * 长宽调整 * 水平翻转 +* 自动增广 ### 参数微调 @@ -170,6 +177,20 @@ python eval.py \ ``` 注意:根据具体模型和任务添加并调整其他参数 +### 指数滑动平均的模型评估 + +注意: 如果你使用指数滑动平均来训练模型(--use_ema=True),并且想要评估指数滑动平均后的模型,需要使用ema_clean.py将训练中保存下来的ema模型名字转换成原始模型参数的名字。 + +``` +python ema_clean.py \ + --ema_model_dir=your_ema_model_dir \ + --cleaned_model_dir=your_cleaned_model_dir + +python eval.py \ + --model=model_name \ + --pretrained_model=your_cleaned_model_dir +``` + ### 模型预测 模型预测(Infer)可以获取一个模型的预测分数或者图像的特征,可以下载[已发布模型及其性能](#已发布模型及其性能)并且设置```path_to_pretrain_model```为模型所在路径。运行如下的命令获得预测结果: @@ -361,6 +382,17 @@ PaddlePaddle/Models ImageClassification 支持自定义数据 |[ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x48d_wsl_pretrained.tar) | 85.37% | 97.69% | 161.722 | | |[Fix_ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/Fix_ResNeXt101_32x48d_wsl_pretrained.tar) | 86.26% | 97.97% | 236.091 | | +### EfficientNet Series +|Model | Top-1 | Top-5 | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | +|- |:-: |:-: |:-: |:-: | +|[EfficientNetB0](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB0_pretrained.tar) | 77.38% | 93.31% | 10.303 | 4.334 | +|[EfficientNetB1](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB1_pretrained.tar) | 79.15% | 94.41% | 15.626 | 6.502 | +|[EfficientNetB2](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB2_pretrained.tar) | 79.85% | 94.74% | 17.847 | 7.558 | +|[EfficientNetB3](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB3_pretrained.tar) | 81.15% | 95.41% | 25.993 | 10.937 | +|[EfficientNetB4](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB4_pretrained.tar) | 82.85% | 96.23% | 47.734 | 18.536 | +|[EfficientNetB5](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB5_pretrained.tar) | 83.62% | 96.72% | 88.578 | 32.102 | +|[EfficientNetB6](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB6_pretrained.tar) | 84.00% | 96.88% | 138.670 | 51.059 | +|[EfficientNetB7](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB7_pretrained.tar) | 84.30% | 96.89% | 234.364 | 82.107 | ## FAQ @@ -400,6 +432,7 @@ PaddlePaddle/Models ImageClassification 支持自定义数据 - SqueezeNet: [SQUEEZENET: ALEXNET-LEVEL ACCURACY WITH 50X FEWER PARAMETERS AND <0.5MB MODEL SIZE](https://arxiv.org/abs/1602.07360), Forrest N. Iandola, Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J. Dally, Kurt Keutzer - ResNeXt101_wsl: [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932), Dhruv Mahajan, Ross Girshick, Vignesh Ramanathan, Kaiming He, Manohar Paluri, Yixuan Li, Ashwin Bharambe, Laurens van der Maaten - Fix_ResNeXt101_wsl: [Fixing the train-test resolution discrepancy](https://arxiv.org/abs/1906.06423), Hugo Touvron, Andrea Vedaldi, Matthijs Douze, Herve ́ Je ́gou +- EfficientNet: [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946), Mingxing Tan, Quoc V. 
Le ## 版本更新 - 2018/12/03 **Stage1**: 更新AlexNet,ResNet50,ResNet101,MobileNetV1 @@ -412,6 +445,7 @@ PaddlePaddle/Models ImageClassification 支持自定义数据 - 2019/07/19 **Stage6**: 更新ShuffleNetV2_x0_25,ShuffleNetV2_x0_33,ShuffleNetV2_x0_5,ShuffleNetV2_x1_0,ShuffleNetV2_x1_5,ShuffleNetV2_x2_0,MobileNetV2_x0_25,MobileNetV2_x1_5,MobileNetV2_x2_0,ResNeXt50_vd_64x4d,ResNeXt101_32x4d,ResNeXt152_32x4d - 2019/08/01 **Stage7**: 更新DarkNet53,DenseNet121,Densenet161,DenseNet169,DenseNet201,DenseNet264,SqueezeNet1_0,SqueezeNet1_1,ResNeXt50_vd_32x4d,ResNeXt152_64x4d,ResNeXt101_32x8d_wsl,ResNeXt101_32x16d_wsl,ResNeXt101_32x32d_wsl,ResNeXt101_32x48d_wsl,Fix_ResNeXt101_32x48d_wsl - 2019/09/11 **Stage8**: 更新ResNet18_vd,ResNet34_vd,MobileNetV1_x0_25,MobileNetV1_x0_5,MobileNetV1_x0_75,MobileNetV2_x0_75,MobilenNetV3_small_x1_0,DPN68,DPN92,DPN98,DPN107,DPN131,ResNeXt101_vd_32x4d,ResNeXt152_vd_64x4d,Xception65,Xception71,Xception41_deeplab,Xception65_deeplab,SE_ResNet50_vd +- 2019/09/20 更新EfficientNet ## 如何贡献代码 diff --git a/PaddleCV/image_classification/README_en.md b/PaddleCV/image_classification/README_en.md index c680c28b..4d8cae3c 100644 --- a/PaddleCV/image_classification/README_en.md +++ b/PaddleCV/image_classification/README_en.md @@ -106,7 +106,9 @@ Solver and hyperparameters: * **lr**: initialized learning rate. Default: 0.1. * **l2_decay**: L2_decay parameter. Default: 1e-4. * **momentum_rate**: momentum_rate. Default: 0.9. -* **step_epochs**: decay step of piecewise step, Default: [30,60,90] +* **step_epochs**: decay step of piecewise step, Default: [30,60,90]. +* **decay_epochs**: decay epoch of exponential decay, Default: 2.4. +* **decay_rate**: decay rate of exponential decay, Default: 0.97. Reader and preprocess: @@ -117,6 +119,7 @@ Reader and preprocess: * **crop_size**: the crop size, Default: 224. * **use_mixup**: whether to use mixup data processing or not. Default:False. * **mixup_alpha**: the mixup_alpha parameter. Default: 0.2. +* **use_aa**: whether to use auto augment data processing or not. Default:False. * **reader_thread**: the number of threads in multi thread reader, Default: 8 * **reader_buf_size**: the buff size of multi thread reader, Default: 2048 * **interpolation**: interpolation method, Default: None @@ -129,7 +132,10 @@ Switch: * **use_gpu**: whether to use GPU or not. Default: True. * **use_label_smoothing**: whether to use label_smoothing or not. Default:False. * **label_smoothing_epsilon**: the label_smoothing_epsilon. Default:0.1. -* **random_seed**: random seed for debugging, Default: 1000 +* **random_seed**: random seed for debugging, Default: 1000. +* **padding_type**: padding type of convolution for efficientNet, Default: "SAME". +* **use_ema**: whether to use ExponentialMovingAverage or not. Default: False. +* **ema_decay**: the value of ExponentialMovingAverage decay rate. Default: 0.9999. **data reader introduction:** Data reader is defined in ```reader.py```, default reader is implemented by opencv. In the [Training](#training) Stage, random crop and flipping are applied, while center crop is applied in the [Evaluation](#evaluation) and [Inference](#inference) stages. Supported data augmentation includes: @@ -139,6 +145,7 @@ Switch: * center crop * resize * flipping +* auto augment ### Finetuning @@ -164,6 +171,20 @@ python eval.py \ Note: Add and adjust other parameters accroding to specific models and tasks. 
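For instance, an EfficientNetB0 checkpoint could be evaluated with the options this patch adds to `eval.py`; the invocation below is illustrative only, the pretrained-model path is a placeholder, and other flags may be needed for a specific setup:

```
python eval.py \
    --model=EfficientNetB0 \
    --pretrained_model=${path_to_pretrain_model} \
    --crop_size=224 \
    --interpolation=2 \
    --padding_type="SAME"
```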
+### ExponentialMovingAverage Evaluation + +Note: if you train model with flag use_ema, and you want to evaluate your ExponentialMovingAverage model, you should clean your saved model first. + +``` +python ema_clean.py \ + --ema_model_dir=your_ema_model_dir \ + --cleaned_model_dir=your_cleaned_model_dir + +python eval.py \ + --model=model_name \ + --pretrained_model=your_cleaned_model_dir +``` + ### Inference **some Inference stage unique parameters** @@ -343,6 +364,18 @@ Pretrained models can be downloaded by clicking related model names. |[ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x48d_wsl_pretrained.tar) | 85.37% | 97.69% | 161.722 | | |[Fix_ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/Fix_ResNeXt101_32x48d_wsl_pretrained.tar) | 86.26% | 97.97% | 236.091 | | +### EfficientNet Series +|Model | Top-1 | Top-5 | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) | +|- |:-: |:-: |:-: |:-: | +|[EfficientNetB0](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB0_pretrained.tar) | 77.38% | 93.31% | 10.303 | 4.334 | +|[EfficientNetB1](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB1_pretrained.tar) | 79.15% | 94.41% | 15.626 | 6.502 | +|[EfficientNetB2](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB2_pretrained.tar) | 79.85% | 94.74% | 17.847 | 7.558 | +|[EfficientNetB3](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB3_pretrained.tar) | 81.15% | 95.41% | 25.993 | 10.937 | +|[EfficientNetB4](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB4_pretrained.tar) | 82.85% | 96.23% | 47.734 | 18.536 | +|[EfficientNetB5](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB5_pretrained.tar) | 83.62% | 96.72% | 88.578 | 32.102 | +|[EfficientNetB6](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB6_pretrained.tar) | 84.00% | 96.88% | 138.670 | 51.059 | +|[EfficientNetB7](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB7_pretrained.tar) | 84.30% | 96.89% | 234.364 | 82.107 | + ## FAQ **Q:** How to solve this problem when I try to train a 6-classes dataset with indicating pretrained_model parameter ? @@ -374,6 +407,7 @@ Enforce failed. Expected x_dims[1] == labels_dims[1], but received x_dims[1]:100 - SqueezeNet: [SQUEEZENET: ALEXNET-LEVEL ACCURACY WITH 50X FEWER PARAMETERS AND <0.5MB MODEL SIZE](https://arxiv.org/abs/1602.07360), Forrest N. Iandola, Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J. Dally, Kurt Keutzer - ResNeXt101_wsl: [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932), Dhruv Mahajan, Ross Girshick, Vignesh Ramanathan, Kaiming He, Manohar Paluri, Yixuan Li, Ashwin Bharambe, Laurens van der Maaten - Fix_ResNeXt101_wsl: [Fixing the train-test resolution discrepancy](https://arxiv.org/abs/1906.06423), Hugo Touvron, Andrea Vedaldi, Matthijs Douze, Herve ́ Je ́gou +- EfficientNet: [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946), Mingxing Tan, Quoc V. Le ## Update @@ -387,6 +421,7 @@ Enforce failed. Expected x_dims[1] == labels_dims[1], but received x_dims[1]:100 - 2019/07/19 **Stage6**: Update ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2_x1_0, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, MobileNetV2_x0_25, MobileNetV2_x1_5, MobileNetV2_x2_0, ResNeXt50_vd_64x4d, ResNeXt101_32x4d, ResNeXt152_32x4d - 2019/08/01 **Stage7**: Update DarkNet53, DenseNet121. 
Densenet161, DenseNet169, DenseNet201, DenseNet264, SqueezeNet1_0, SqueezeNet1_1, ResNeXt50_vd_32x4d, ResNeXt152_64x4d, ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl - 2019/09/11 **Stage8**: Update ResNet18_vd,ResNet34_vd,MobileNetV1_x0_25,MobileNetV1_x0_5,MobileNetV1_x0_75,MobileNetV2_x0_75,MobilenNetV3_small_x1_0,DPN68,DPN92,DPN98,DPN107,DPN131,ResNeXt101_vd_32x4d,ResNeXt152_vd_64x4d,Xception65,Xception71,Xception41_deeplab,Xception65_deeplab,SE_ResNet50_vd +- 2019/09/20 Update EfficientNet ## Contribute diff --git a/PaddleCV/image_classification/ema_clean.py b/PaddleCV/image_classification/ema_clean.py new file mode 100644 index 00000000..39f1fe7e --- /dev/null +++ b/PaddleCV/image_classification/ema_clean.py @@ -0,0 +1,47 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import argparse +import functools +import shutil +from utils import print_arguments, add_arguments + +parser = argparse.ArgumentParser(description=__doc__) +# yapf: disable +add_arg = functools.partial(add_arguments, argparser=parser) +add_arg('ema_model_dir', str, None, "The directory of model which use ExponentialMovingAverage to train") +add_arg('cleaned_model_dir', str, None, "The directory of cleaned model") +# yapf: enable + +def main(): + args = parser.parse_args() + print_arguments(args) + if not os.path.exists(args.cleaned_model_dir): + os.makedirs(args.cleaned_model_dir) + + items = os.listdir(args.ema_model_dir) + for item in items: + if item.find('ema') > -1: + item_clean = item.replace('_ema_0', '') + shutil.copyfile(os.path.join(args.ema_model_dir, item), + os.path.join(args.cleaned_model_dir, item_clean)) + + +if __name__ == '__main__': + main() diff --git a/PaddleCV/image_classification/eval.py b/PaddleCV/image_classification/eval.py index 7869bf52..31a1e64f 100644 --- a/PaddleCV/image_classification/eval.py +++ b/PaddleCV/image_classification/eval.py @@ -46,6 +46,8 @@ add_arg('reader_buf_size', int, 2048, "The buf size of multi t parser.add_argument('--image_mean', nargs='+', type=float, default=[0.485, 0.456, 0.406], help="The mean of input image data") parser.add_argument('--image_std', nargs='+', type=float, default=[0.229, 0.224, 0.225], help="The std of input image data") add_arg('crop_size', int, 224, "The value of crop size") +add_arg('interpolation', int, None, "The interpolation mode") +add_arg('padding_type', str, "SAME", "Padding type of convolution") # yapf: enable @@ -64,7 +66,10 @@ def eval(args): label = fluid.layers.data(name='label', shape=[1], dtype='int64') # model definition - model = models.__dict__[args.model]() + if args.model.startswith('EfficientNet'): + model = models.__dict__[args.model](is_test=True, padding_type=args.padding_type) + else: + model = models.__dict__[args.model]() if args.model == "GoogLeNet": out0, out1, out2 = 
model.net(input=image, class_dim=args.class_dim) diff --git a/PaddleCV/image_classification/infer.py b/PaddleCV/image_classification/infer.py index 4d61805d..93e06cb1 100644 --- a/PaddleCV/image_classification/infer.py +++ b/PaddleCV/image_classification/infer.py @@ -47,7 +47,8 @@ parser.add_argument('--image_mean', nargs='+', type=float, default=[0.485, 0.456 parser.add_argument('--image_std', nargs='+', type=float, default=[0.229, 0.224, 0.225], help="The std of input image data") add_arg('crop_size', int, 224, "The value of crop size") add_arg('topk', int, 1, "topk") -add_arg('label_path', str, "./utils/tools/readable_label.txt", "readable label filepath") +add_arg('label_path', str, "./utils/tools/readable_label.txt", "readable label filepath") +add_arg('interpolation', int, None, "The interpolation mode") # yapf: enable diff --git a/PaddleCV/image_classification/models/__init__.py b/PaddleCV/image_classification/models/__init__.py index 48f9725d..9ebb3d56 100644 --- a/PaddleCV/image_classification/models/__init__.py +++ b/PaddleCV/image_classification/models/__init__.py @@ -37,3 +37,4 @@ from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseN from .squeezenet import SqueezeNet1_0, SqueezeNet1_1 from .darknet import DarkNet53 from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl +from .efficientnet import EfficientNet, EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7 diff --git a/PaddleCV/image_classification/models/efficientnet.py b/PaddleCV/image_classification/models/efficientnet.py new file mode 100644 index 00000000..8faea841 --- /dev/null +++ b/PaddleCV/image_classification/models/efficientnet.py @@ -0,0 +1,453 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle.fluid as fluid +import collections +import re +import math +import copy +from .layers import conv2d, init_batch_norm_layer, init_fc_layer + + +__all__ = ['EfficientNet', 'EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', 'EfficientNetB4', + 'EfficientNetB5', 'EfficientNetB6', 'EfficientNetB7'] + +GlobalParams = collections.namedtuple('GlobalParams', [ + 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', + 'num_classes', 'width_coefficient', 'depth_coefficient', + 'depth_divisor', 'min_depth', 'drop_connect_rate', ]) + +BlockArgs = collections.namedtuple('BlockArgs', [ + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', + 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) + +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + + +def efficientnet_params(model_name): + """ Map EfficientNet model name to parameter coefficients. 
""" + params_dict = { + # Coefficients: width,depth,resolution,dropout + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + } + return params_dict[model_name] + + +def efficientnet(width_coefficient=None, depth_coefficient=None, + dropout_rate=0.2, drop_connect_rate=0.2): + """ Get block arguments according to parameter and coefficients. """ + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + dropout_rate=dropout_rate, + drop_connect_rate=drop_connect_rate, + num_classes=1000, + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + depth_divisor=8, + min_depth=None + ) + + return blocks_args, global_params + + +def get_model_params(model_name, override_params): + """ Get the block args and global params for a given model """ + if model_name.startswith('efficientnet'): + w, d, _, p = efficientnet_params(model_name) + blocks_args, global_params = efficientnet(width_coefficient=w, depth_coefficient=d, dropout_rate=p) + else: + raise NotImplementedError('model name is not pre-defined: %s' % model_name) + if override_params: + global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +def round_filters(filters, global_params): + """ Calculate and round number of filters based on depth multiplier. """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """ Round number of filters based on depth multiplier. 
""" + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +class EfficientNet(): + def __init__(self, name='b0', padding_type='SAME', override_params=None, is_test=False): + valid_names = ['b' + str(i) for i in range(8)] + assert name in valid_names, 'efficient name should be in b0~b7' + model_name = 'efficientnet-' + name + self._blocks_args, self._global_params = get_model_params(model_name, override_params) + self._bn_mom = self._global_params.batch_norm_momentum + self._bn_eps = self._global_params.batch_norm_epsilon + self.is_test = is_test + self.padding_type = padding_type + + def net(self, input, class_dim=1000, is_test=False): + + conv = self.extract_features(input, is_test=is_test) + + out_channels = round_filters(1280, self._global_params) + conv = self.conv_bn_layer(conv, + num_filters=out_channels, + filter_size=1, + bn_act='swish', + bn_mom=self._bn_mom, + bn_eps=self._bn_eps, + padding_type=self.padding_type, + name='', + conv_name='_conv_head', + bn_name='_bn1') + + pool = fluid.layers.pool2d(input=conv, pool_type='avg', global_pooling=True, use_cudnn=False) + + if self._global_params.dropout_rate: + pool = fluid.layers.dropout(pool, self._global_params.dropout_rate, dropout_implementation='upscale_in_train') + + param_attr, bias_attr = init_fc_layer(class_dim, '_fc') + out = fluid.layers.fc(pool, class_dim, name='_fc', param_attr=param_attr, bias_attr=bias_attr) + return out + + def _drop_connect(self, inputs, prob, is_test): + if is_test: + return inputs + keep_prob = 1.0 - prob + random_tensor = keep_prob + fluid.layers.uniform_random_batch_size_like(inputs, [-1, 1, 1, 1], min=0., max=1.) + binary_tensor = fluid.layers.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + + def _expand_conv_norm(self, inputs, block_args, is_test, name=None): + # Expansion phase + oup = block_args.input_filters * block_args.expand_ratio # number of output channels + + if block_args.expand_ratio != 1: + conv = self.conv_bn_layer(inputs, + num_filters=oup, + filter_size=1, + bn_act=None, + bn_mom=self._bn_mom, + bn_eps=self._bn_eps, + padding_type=self.padding_type, + name=name, + conv_name=name + '_expand_conv', + bn_name='_bn0') + + return conv + + def _depthwise_conv_norm(self, inputs, block_args, is_test, name=None): + k = block_args.kernel_size + s = block_args.stride + if isinstance(s, list) or isinstance(s, tuple): + s = s[0] + oup = block_args.input_filters * block_args.expand_ratio # number of output channels + + conv = self.conv_bn_layer(inputs, + num_filters=oup, + filter_size=k, + stride=s, + num_groups=oup, + bn_act=None, + padding_type=self.padding_type, + bn_mom=self._bn_mom, + bn_eps=self._bn_eps, + name=name, + use_cudnn=False, + conv_name=name + '_depthwise_conv', + bn_name='_bn1') + + return conv + + def _project_conv_norm(self, inputs, block_args, is_test, name=None): + final_oup = block_args.output_filters + conv = self.conv_bn_layer(inputs, + num_filters=final_oup, + filter_size=1, + bn_act=None, + padding_type=self.padding_type, + bn_mom=self._bn_mom, + bn_eps=self._bn_eps, + name=name, + conv_name=name + '_project_conv', + bn_name='_bn2') + return conv + + def conv_bn_layer(self, input, filter_size, num_filters, stride=1, num_groups=1, padding_type="SAME", conv_act=None, + bn_act='swish', use_cudnn=True, use_bn=True, bn_mom=0.9, bn_eps=1e-05, use_bias=False, name=None, + conv_name=None, bn_name=None): + conv = conv2d( + input=input, + 
num_filters=num_filters, + filter_size=filter_size, + stride=stride, + groups=num_groups, + act=conv_act, + padding_type=padding_type, + use_cudnn=use_cudnn, + name=conv_name, + use_bias=use_bias) + + if use_bn == False: + return conv + else: + bn_name = name + bn_name + param_attr, bias_attr = init_batch_norm_layer(bn_name) + return fluid.layers.batch_norm(input=conv, + act=bn_act, + momentum=bn_mom, + epsilon=bn_eps, + name=bn_name, + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + param_attr=param_attr, + bias_attr=bias_attr) + + def _conv_stem_norm(self, inputs, is_test): + out_channels = round_filters(32, self._global_params) + bn = self.conv_bn_layer(inputs, num_filters=out_channels, filter_size=3, stride=2, bn_act=None, + bn_mom=self._bn_mom, padding_type=self.padding_type, + bn_eps=self._bn_eps, name='', conv_name='_conv_stem', bn_name='_bn0') + + return bn + + def mb_conv_block(self, inputs, block_args, is_test=False, drop_connect_rate=None, name=None): + # Expansion and Depthwise Convolution + oup = block_args.input_filters * block_args.expand_ratio # number of output channels + has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1) + id_skip = block_args.id_skip # skip connection and drop connect + conv = inputs + if block_args.expand_ratio != 1: + conv = fluid.layers.swish(self._expand_conv_norm(conv, block_args, is_test, name)) + + conv = fluid.layers.swish(self._depthwise_conv_norm(conv, block_args, is_test, name)) + + # Squeeze and Excitation + if has_se: + num_squeezed_channels = max(1, int(block_args.input_filters * block_args.se_ratio)) + conv = self.se_block(conv, num_squeezed_channels, oup, name) + + conv = self._project_conv_norm(conv, block_args, is_test, name) + + # Skip connection and drop connect + input_filters, output_filters = block_args.input_filters, block_args.output_filters + if id_skip and block_args.stride == 1 and input_filters == output_filters: + if drop_connect_rate: + conv = self._drop_connect(conv, drop_connect_rate, self.is_test) + conv = fluid.layers.elementwise_add(conv, inputs) + + return conv + + def se_block(self, inputs, num_squeezed_channels, oup, name): + x_squeezed = fluid.layers.pool2d( + input=inputs, + pool_type='avg', + global_pooling=True, + use_cudnn=False) + x_squeezed = conv2d(x_squeezed, + num_filters=num_squeezed_channels, + filter_size=1, + use_bias=True, + padding_type=self.padding_type, + act='swish', + name=name + '_se_reduce') + x_squeezed = conv2d(x_squeezed, + num_filters=oup, + filter_size=1, + use_bias=True, + padding_type=self.padding_type, + name=name + '_se_expand') + se_out = inputs * fluid.layers.sigmoid(x_squeezed) + return se_out + + def extract_features(self, inputs, is_test): + """ Returns output of the final convolution layer """ + + conv = fluid.layers.swish(self._conv_stem_norm(inputs, is_test=is_test)) + + block_args_copy = copy.deepcopy(self._blocks_args) + idx = 0 + block_size = 0 + for block_arg in block_args_copy: + block_arg = block_arg._replace( + input_filters=round_filters(block_arg.input_filters, self._global_params), + output_filters=round_filters(block_arg.output_filters, self._global_params), + num_repeat=round_repeats(block_arg.num_repeat, self._global_params) + ) + block_size += 1 + for _ in range(block_arg.num_repeat - 1): + block_size += 1 + + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. 
+ block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params) + ) + + # The first block needs to take care of stride and filter size increase. + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / block_size + conv = self.mb_conv_block(conv, block_args, is_test, drop_connect_rate, '_blocks.' + str(idx) + '.') + + idx += 1 + if block_args.num_repeat > 1: + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / block_size + conv = self.mb_conv_block(conv, block_args, is_test, drop_connect_rate, '_blocks.' + str(idx) + '.') + idx += 1 + + return conv + + def shortcut(self, input, data_residual): + return fluid.layers.elementwise_add(input, data_residual) + + +class BlockDecoder(object): + """ Block Decoder for readability, straight from the official TensorFlow repository """ + + @staticmethod + def _decode_block_string(block_string): + """ Gets a block through a string notation of arguments. """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 2 and options['s'][0] == options['s'][1])) + + return BlockArgs( + kernel_size=int(options['k']), + num_repeat=int(options['r']), + input_filters=int(options['i']), + output_filters=int(options['o']), + expand_ratio=int(options['e']), + id_skip=('noskip' not in block_string), + se_ratio=float(options['se']) if 'se' in options else None, + stride=[int(options['s'][0])]) + + @staticmethod + def _encode_block_string(block): + """Encodes a block to a string.""" + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """ + Decodes a list of string notations to specify blocks inside the network. + + :param string_list: a list of strings, each string is a notation of block + :return: a list of BlockArgs namedtuples of block args + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """ + Encodes a list of BlockArgs to a list of strings. 
+ + :param blocks_args: a list of BlockArgs namedtuples of block args + :return: a list of strings, each string is a notation of block + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def EfficientNetB0(is_test=False, padding_type='SAME', override_params=None): + model = EfficientNet(name='b0', is_test=is_test, padding_type=padding_type, override_params=override_params) + return model + + +def EfficientNetB1(is_test=False, padding_type='SAME', override_params=None): + model = EfficientNet(name='b1', is_test=is_test, padding_type=padding_type, override_params=override_params) + return model + + +def EfficientNetB2(is_test=False, padding_type='SAME', override_params=None): + model = EfficientNet(name='b2', is_test=is_test, padding_type=padding_type, override_params=override_params) + return model + + +def EfficientNetB3(is_test=False, padding_type='SAME', override_params=None): + model = EfficientNet(name='b3', is_test=is_test, padding_type=padding_type, override_params=override_params) + return model + + +def EfficientNetB4(is_test=False, padding_type='SAME', override_params=None): + model = EfficientNet(name='b4', is_test=is_test, padding_type=padding_type, override_params=override_params) + return model + + +def EfficientNetB5(is_test=False, padding_type='SAME', override_params=None): + model = EfficientNet(name='b5', is_test=is_test, padding_type=padding_type, override_params=override_params) + return model + + +def EfficientNetB6(is_test=False, padding_type='SAME', override_params=None): + model = EfficientNet(name='b6', is_test=is_test, padding_type=padding_type, override_params=override_params) + return model + + +def EfficientNetB7(is_test=False, padding_type='SAME', override_params=None): + model = EfficientNet(name='b7', is_test=is_test, padding_type=padding_type, override_params=override_params) + return model \ No newline at end of file diff --git a/PaddleCV/image_classification/models/layers.py b/PaddleCV/image_classification/models/layers.py new file mode 100644 index 00000000..5900f8cc --- /dev/null +++ b/PaddleCV/image_classification/models/layers.py @@ -0,0 +1,222 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle.fluid as fluid +import math +import warnings + +def initial_type(name, + input, + op_type, + fan_out, + init="google", + use_bias=False, + filter_size=0, + stddev=0.02): + if init == "kaiming": + if op_type == 'conv': + fan_in = input.shape[1] * filter_size * filter_size + elif op_type == 'deconv': + fan_in = fan_out * filter_size * filter_size + else: + if len(input.shape) > 2: + fan_in = input.shape[1] * input.shape[2] * input.shape[3] + else: + fan_in = input.shape[1] + bound = 1 / math.sqrt(fan_in) + param_attr = fluid.ParamAttr( + name=name + "_weights", + initializer=fluid.initializer.Uniform( + low=-bound, high=bound)) + if use_bias == True: + bias_attr = fluid.ParamAttr( + name=name + '_offset', + initializer=fluid.initializer.Uniform( + low=-bound, high=bound)) + else: + bias_attr = False + elif init == 'google': + n = filter_size * filter_size * fan_out + param_attr = fluid.ParamAttr( + name=name + "_weights", + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=math.sqrt(2.0 / n))) + if use_bias == True: + bias_attr = fluid.ParamAttr( + name=name + "_offset", initializer=fluid.initializer.Constant(0.0)) + else: + bias_attr = False + + else: + param_attr = 
fluid.ParamAttr( + name=name + "_weights", + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=stddev)) + if use_bias == True: + bias_attr = fluid.ParamAttr( + name=name + "_offset", initializer=fluid.initializer.Constant(0.0)) + else: + bias_attr = False + return param_attr, bias_attr + +def cal_padding(img_size, stride, filter_size, dilation=1): + """Calculate padding size.""" + if img_size % stride == 0: + out_size = max(filter_size - stride, 0) + else: + out_size = max(filter_size - (img_size % stride), 0) + return out_size // 2, out_size - out_size // 2 + +def init_batch_norm_layer(name="batch_norm"): + param_attr = fluid.ParamAttr( + name=name + '_scale', initializer=fluid.initializer.Constant(1.0)) + bias_attr = fluid.ParamAttr( + name=name + '_offset', initializer=fluid.initializer.Constant(value=0.0)) + return param_attr, bias_attr + +def init_fc_layer(fout, name='fc'): + n = fout # fan-out + init_range = 1.0 / math.sqrt(n) + + param_attr = fluid.ParamAttr( + name=name + '_weights', initializer=fluid.initializer.UniformInitializer( + low=-init_range, high=init_range)) + bias_attr = fluid.ParamAttr( + name=name + '_offset', initializer=fluid.initializer.Constant(value=0.0)) + return param_attr, bias_attr + +def norm_layer(input, norm_type='batch_norm', name=None): + if norm_type == 'batch_norm': + param_attr = fluid.ParamAttr( + name=name + '_weights', initializer=fluid.initializer.Constant(1.0)) + bias_attr = fluid.ParamAttr( + name=name + '_offset', initializer=fluid.initializer.Constant(value=0.0)) + return fluid.layers.batch_norm( + input, + param_attr=param_attr, + bias_attr=bias_attr, + moving_mean_name=name + '_mean', + moving_variance_name=name + '_variance') + + elif norm_type == 'instance_norm': + helper = fluid.layer_helper.LayerHelper("instance_norm", **locals()) + dtype = helper.input_dtype() + epsilon = 1e-5 + mean = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True) + var = fluid.layers.reduce_mean( + fluid.layers.square(input - mean), dim=[2, 3], keep_dim=True) + if name is not None: + scale_name = name + "_scale" + offset_name = name + "_offset" + scale_param = fluid.ParamAttr( + name=scale_name, + initializer=fluid.initializer.Constant(1.0), + trainable=True) + offset_param = fluid.ParamAttr( + name=offset_name, + initializer=fluid.initializer.Constant(0.0), + trainable=True) + scale = helper.create_parameter( + attr=scale_param, shape=input.shape[1:2], dtype=dtype) + offset = helper.create_parameter( + attr=offset_param, shape=input.shape[1:2], dtype=dtype) + + tmp = fluid.layers.elementwise_mul(x=(input - mean), y=scale, axis=1) + tmp = tmp / fluid.layers.sqrt(var + epsilon) + tmp = fluid.layers.elementwise_add(tmp, offset, axis=1) + return tmp + else: + raise NotImplementedError("norm tyoe: [%s] is not support" % norm_type) + + +def conv2d(input, + num_filters=64, + filter_size=7, + stride=1, + stddev=0.02, + padding=0, + groups=None, + name="conv2d", + norm=None, + act=None, + relufactor=0.0, + use_bias=False, + padding_type=None, + initial="normal", + use_cudnn=True): + + if padding != 0 and padding_type != None: + warnings.warn( + 'padding value and padding type are set in the same time, and the final padding width and padding height are computed by padding_type' + ) + + param_attr, bias_attr = initial_type( + name=name, + input=input, + op_type='conv', + fan_out=num_filters, + init=initial, + use_bias=use_bias, + filter_size=filter_size, + stddev=stddev) + + def get_padding(filter_size, stride=1, dilation=1): + padding = ((stride - 
1) + dilation * (filter_size - 1)) // 2 + return padding + + need_crop = False + if padding_type == "SAME": + top_padding, bottom_padding = cal_padding(input.shape[2], stride, + filter_size) + left_padding, right_padding = cal_padding(input.shape[2], stride, + filter_size) + height_padding = bottom_padding + width_padding = right_padding + if top_padding != bottom_padding or left_padding != right_padding: + height_padding = top_padding + stride + width_padding = left_padding + stride + need_crop = True + padding = [height_padding, width_padding] + elif padding_type == "VALID": + height_padding = 0 + width_padding = 0 + padding = [height_padding, width_padding] + elif padding_type == "DYNAMIC": + padding = get_padding(filter_size, stride) + else: + padding = padding + + conv = fluid.layers.conv2d( + input, + num_filters, + filter_size, + groups=groups, + name=name, + stride=stride, + padding=padding, + use_cudnn=use_cudnn, + param_attr=param_attr, + bias_attr=bias_attr) + + if need_crop: + conv = conv[:, :, 1:, 1:] + + if norm is not None: + conv = norm_layer(input=conv, norm_type=norm, name=name + "_norm") + if act == 'relu': + conv = fluid.layers.relu(conv, name=name + '_relu') + elif act == 'leaky_relu': + conv = fluid.layers.leaky_relu( + conv, alpha=relufactor, name=name + '_leaky_relu') + elif act == 'tanh': + conv = fluid.layers.tanh(conv, name=name + '_tanh') + elif act == 'sigmoid': + conv = fluid.layers.sigmoid(conv, name=name + '_sigmoid') + elif act == 'swish': + conv = fluid.layers.swish(conv, name=name + '_swish') + elif act == None: + conv = conv + else: + raise NotImplementedError("activation: [%s] is not support" %act) + + return conv \ No newline at end of file diff --git a/PaddleCV/image_classification/reader.py b/PaddleCV/image_classification/reader.py index 9ff65bf7..997f788f 100644 --- a/PaddleCV/image_classification/reader.py +++ b/PaddleCV/image_classification/reader.py @@ -18,11 +18,12 @@ import random import functools import numpy as np import cv2 -import io -import signal import paddle -import paddle.fluid as fluid +from utils.autoaugment import ImageNetPolicy +from PIL import Image + +policy = None random.seed(0) np.random.seed(0) @@ -200,7 +201,6 @@ def create_mixup_reader(settings, rd): return mixup_reader - def process_image(sample, settings, mode, color_jitter, rotate): """ process_image """ @@ -215,7 +215,7 @@ def process_image(sample, settings, mode, color_jitter, rotate): if rotate: img = rotate_image(img) if crop_size > 0: - img = random_crop(img, crop_size, settings) + img = random_crop(img, crop_size, settings, interpolation=settings.interpolation) if color_jitter: img = distort_color(img) if np.random.randint(0, 2) == 1: @@ -223,10 +223,18 @@ def process_image(sample, settings, mode, color_jitter, rotate): else: if crop_size > 0: target_size = settings.resize_short_size - img = resize_short(img, target_size) + img = resize_short(img, target_size, interpolation=settings.interpolation) img = crop_image(img, target_size=crop_size, center=True) - img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255 + img = img[:, :, ::-1] + + if 'use_aa' in settings and settings.use_aa and mode == 'train': + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + img = policy(img) + img = np.asarray(img) + + img = img.astype('float32').transpose((2, 0, 1)) / 255 img_mean = np.array(mean).reshape((3, 1, 1)) img_std = np.array(std).reshape((3, 1, 1)) img -= img_mean @@ -294,6 +302,11 @@ def train(settings): assert os.path.isfile( file_list), "{} doesn't 
exist, please check data list path".format( file_list) + + if 'use_aa' in settings and settings.use_aa: + global policy + policy = ImageNetPolicy() + reader = _reader_creator( settings, file_list, @@ -317,6 +330,7 @@ def val(settings): Returns: eval reader """ + file_list = os.path.join(settings.data_dir, 'val_list.txt') assert os.path.isfile( file_list), "{} doesn't exist, please check data list path".format( diff --git a/PaddleCV/image_classification/scripts/train/EfficientNetB0.sh b/PaddleCV/image_classification/scripts/train/EfficientNetB0.sh new file mode 100644 index 00000000..37a5d9a4 --- /dev/null +++ b/PaddleCV/image_classification/scripts/train/EfficientNetB0.sh @@ -0,0 +1,27 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export FLAGS_fast_eager_deletion_mode=1 +export FLAGS_eager_delete_tensor_gb=0.0 +export FLAGS_fraction_of_gpu_memory_to_use=0.96 + + +python -u train.py \ + --model=EfficientNet \ + --batch_size=512 \ + --test_batch_size=128 \ + --total_images=1281167 \ + --class_dim=1000 \ + --image_shape=3,224,224 \ + --resize_short_size=256 \ + --model_save_dir=output/ \ + --lr_strategy=exponential_decay_warmup \ + --lr=0.032 \ + --num_epochs=360 \ + --l2_decay=1e-5 \ + --use_label_smoothing=True \ + --label_smoothing_epsilon=0.1 \ + --use_ema=True \ + --ema_decay=0.9999 \ + --drop_connect_rate=0.1 \ + --padding_type="SAME" \ + --interpolation=2 \ + --use_aa=True diff --git a/PaddleCV/image_classification/train.py b/PaddleCV/image_classification/train.py index ae3e03f0..dfd0f591 100755 --- a/PaddleCV/image_classification/train.py +++ b/PaddleCV/image_classification/train.py @@ -20,9 +20,6 @@ import os import numpy as np import time import sys -import functools -import math - def set_paddle_flags(flags): for key, value in flags.items(): @@ -38,10 +35,6 @@ set_paddle_flags({ 'FLAGS_fraction_of_gpu_memory_to_use': 0.98 }) -import argparse -import functools -import subprocess - import paddle import paddle.fluid as fluid import reader @@ -63,7 +56,13 @@ def build_program(is_train, main_prog, startup_prog, args): train mode: [Loss, global_lr, py_reader] test mode: [Loss, py_reader] """ - model = models.__dict__[args.model]() + if args.model.startswith('EfficientNet'): + is_test = False if is_train else True + override_params = {"drop_connect_rate": args.drop_connect_rate} + padding_type = args.padding_type + model = models.__dict__[args.model](is_test=is_test, override_params=override_params, padding_type=padding_type) + else: + model = models.__dict__[args.model]() with fluid.program_guard(main_prog, startup_prog): if args.random_seed: main_prog.random_seed = args.random_seed @@ -79,9 +78,50 @@ def build_program(is_train, main_prog, startup_prog, args): global_lr = optimizer._global_learning_rate() global_lr.persistable = True loss_out.append(global_lr) + if args.use_ema: + global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter() + ema = ExponentialMovingAverage(args.ema_decay, thres_steps=global_steps) + ema.update() + loss_out.append(ema) loss_out.append(py_reader) return loss_out +def validate(args, test_py_reader, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record): + test_batch_time_record = [] + test_batch_metrics_record = [] + test_batch_id = 0 + test_py_reader.start() + try: + while True: + t1 = time.time() + test_batch_metrics = exe.run(program=test_prog, + fetch_list=test_fetch_list) + t2 = time.time() + test_batch_elapse = t2 - t1 + test_batch_time_record.append(test_batch_elapse) + + test_batch_metrics_avg = np.mean( + 
np.array(test_batch_metrics), axis=1) + test_batch_metrics_record.append(test_batch_metrics_avg) + + print_info(pass_id, test_batch_id, args.print_step, + test_batch_metrics_avg, test_batch_elapse, "batch") + sys.stdout.flush() + test_batch_id += 1 + + except fluid.core.EOFException: + test_py_reader.reset() + #train_epoch_time_avg = np.mean(np.array(train_batch_time_record)) + train_epoch_metrics_avg = np.mean( + np.array(train_batch_metrics_record), axis=0) + + test_epoch_time_avg = np.mean(np.array(test_batch_time_record)) + test_epoch_metrics_avg = np.mean( + np.array(test_batch_metrics_record), axis=0) + + print_info(pass_id, 0, 0, + list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg), + test_epoch_time_avg, "epoch") def train(args): """Train model @@ -99,7 +139,11 @@ def train(args): startup_prog=startup_prog, args=args) train_py_reader = train_out[-1] - train_fetch_vars = train_out[:-1] + if args.use_ema: + train_fetch_vars = train_out[:-2] + ema = train_out[-2] + else: + train_fetch_vars = train_out[:-1] train_fetch_list = [var.name for var in train_fetch_vars] @@ -143,11 +187,8 @@ def train(args): for pass_id in range(args.num_epochs): train_batch_id = 0 - test_batch_id = 0 train_batch_time_record = [] - test_batch_time_record = [] train_batch_metrics_record = [] - test_batch_metrics_record = [] train_py_reader.start() @@ -171,38 +212,13 @@ def train(args): except fluid.core.EOFException: train_py_reader.reset() - test_py_reader.start() - try: - while True: - t1 = time.time() - test_batch_metrics = exe.run(program=test_prog, - fetch_list=test_fetch_list) - t2 = time.time() - test_batch_elapse = t2 - t1 - test_batch_time_record.append(test_batch_elapse) - - test_batch_metrics_avg = np.mean( - np.array(test_batch_metrics), axis=1) - test_batch_metrics_record.append(test_batch_metrics_avg) + if args.use_ema: + print('ExponentialMovingAverage validate start...') + with ema.apply(exe): + validate(args, test_py_reader, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record) + print('ExponentialMovingAverage validate over!') - print_info(pass_id, test_batch_id, args.print_step, - test_batch_metrics_avg, test_batch_elapse, "batch") - sys.stdout.flush() - test_batch_id += 1 - - except fluid.core.EOFException: - test_py_reader.reset() - train_epoch_time_avg = np.mean(np.array(train_batch_time_record)) - train_epoch_metrics_avg = np.mean( - np.array(train_batch_metrics_record), axis=0) - - test_epoch_time_avg = np.mean(np.array(test_batch_time_record)) - test_epoch_metrics_avg = np.mean( - np.array(test_batch_metrics_record), axis=0) - - print_info(pass_id, 0, 0, - list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg), - 0, "epoch") + validate(args, test_py_reader, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record) #For now, save model per epoch. if pass_id % args.save_step == 0: save_model(args, exe, train_prog, pass_id) diff --git a/PaddleCV/image_classification/utils/__init__.py b/PaddleCV/image_classification/utils/__init__.py index 995da6a3..61ca1cbb 100644 --- a/PaddleCV/image_classification/utils/__init__.py +++ b/PaddleCV/image_classification/utils/__init__.py @@ -11,5 +11,5 @@ #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #See the License for the specific language governing permissions and #limitations under the License. 
-from .optimizer import cosine_decay, lr_warmup, cosine_decay_with_warmup, Optimizer, create_optimizer -from .utility import add_arguments, print_arguments, parse_args, check_gpu, check_args, init_model, save_model, create_pyreader, print_info, best_strategy_compiled, init_model, save_model +from .optimizer import cosine_decay, lr_warmup, cosine_decay_with_warmup, exponential_decay_with_warmup, Optimizer, create_optimizer +from .utility import add_arguments, print_arguments, parse_args, check_gpu, check_args, init_model, save_model, create_pyreader, print_info, best_strategy_compiled, init_model, save_model, ExponentialMovingAverage diff --git a/PaddleCV/image_classification/utils/autoaugment.py b/PaddleCV/image_classification/utils/autoaugment.py new file mode 100644 index 00000000..27c6fff3 --- /dev/null +++ b/PaddleCV/image_classification/utils/autoaugment.py @@ -0,0 +1,243 @@ +""" +This code is based on https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py +""" +from PIL import Image, ImageEnhance, ImageOps +import numpy as np +import random + + +class ImageNetPolicy(object): + """ Randomly choose one of the best 24 Sub-policies on ImageNet. + + Example: + >>> policy = ImageNetPolicy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> ImageNetPolicy(), + >>> transforms.ToTensor()]) + """ + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + + SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), + SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), + SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), + SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), + + SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), + SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), + SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), + + SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), + SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), + SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), + SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), + SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), + + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) + ] + + def __call__(self, img, policy_idx=None): + if policy_idx is None or not isinstance(policy_idx, int): + policy_idx = random.randint(0, len(self.policies) - 1) + else: + policy_idx = policy_idx % len(self.policies) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment ImageNet Policy" + + +class CIFAR10Policy(object): + """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
+ + Example: + >>> policy = CIFAR10Policy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> CIFAR10Policy(), + >>> transforms.ToTensor()]) + """ + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), + SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), + SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), + SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), + + SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), + SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), + SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), + SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), + + SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), + SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), + SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), + SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), + SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), + + SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), + SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), + SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), + SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), + SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), + + SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), + SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) + ] + + def __call__(self, img, policy_idx=None): + if policy_idx is None or not isinstance(policy_idx, int): + policy_idx = random.randint(0, len(self.policies) - 1) + else: + policy_idx = policy_idx % len(self.policies) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment CIFAR10 Policy" + + +class SVHNPolicy(object): + """ Randomly choose one of the best 25 Sub-policies on SVHN. 
+ + Example: + >>> policy = SVHNPolicy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> SVHNPolicy(), + >>> transforms.ToTensor()]) + """ + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), + SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), + SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), + + SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), + SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), + SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), + SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), + + SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), + SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), + SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), + SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), + + SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), + SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), + SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), + SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), + + SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), + SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), + SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), + SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), + SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) + ] + + def __call__(self, img, policy_idx=None): + if policy_idx is None or not isinstance(policy_idx, int): + policy_idx = random.randint(0, len(self.policies) - 1) + else: + policy_idx = policy_idx % len(self.policies) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment SVHN Policy" + + +class SubPolicy(object): + def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): + ranges = { + "shearX": np.linspace(0, 0.3, 10), + "shearY": np.linspace(0, 0.3, 10), + "translateX": np.linspace(0, 150 / 331, 10), + "translateY": np.linspace(0, 150 / 331, 10), + "rotate": np.linspace(0, 30, 10), + "color": np.linspace(0.0, 0.9, 10), + "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), + "solarize": np.linspace(256, 0, 10), + "contrast": np.linspace(0.0, 0.9, 10), + "sharpness": np.linspace(0.0, 0.9, 10), + "brightness": np.linspace(0.0, 0.9, 10), + "autocontrast": [0] * 10, + "equalize": [0] * 10, + "invert": [0] * 10 + } + + # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand + def rotate_with_fill(img, magnitude): + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) + + func = { + "shearX": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, fillcolor=fillcolor), + "shearY": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), + Image.BICUBIC, 
fillcolor=fillcolor), + "translateX": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), + fillcolor=fillcolor), + "translateY": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), + fillcolor=fillcolor), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), + "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img) + } + + self.p1 = p1 + self.operation1 = func[operation1] + self.magnitude1 = ranges[operation1][magnitude_idx1] + self.p2 = p2 + self.operation2 = func[operation2] + self.magnitude2 = ranges[operation2][magnitude_idx2] + + def __call__(self, img): + if random.random() < self.p1: img = self.operation1(img, self.magnitude1) + if random.random() < self.p2: img = self.operation2(img, self.magnitude2) + return img diff --git a/PaddleCV/image_classification/utils/optimizer.py b/PaddleCV/image_classification/utils/optimizer.py index 8efa16a4..dd33bd37 100644 --- a/PaddleCV/image_classification/utils/optimizer.py +++ b/PaddleCV/image_classification/utils/optimizer.py @@ -67,6 +67,33 @@ def cosine_decay_with_warmup(learning_rate, step_each_epoch, epochs=120): fluid.layers.tensor.assign(input=decayed_lr, output=lr) return lr +def exponential_decay_with_warmup(learning_rate, step_each_epoch, decay_epochs, decay_rate=0.97, warm_up_epoch=5.0): + """Applies exponential decay to the learning rate. 
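+
+    The resulting schedule is, in effect:
+
+        if epoch < warm_up_epoch:
+            lr = learning_rate * global_step / (step_each_epoch * warm_up_epoch)
+        else:
+            lr = learning_rate * decay_rate ** floor(
+                (global_step - warm_up_epoch * step_each_epoch) / decay_epochs)
+
+    Note that ``decay_epochs`` here is measured in steps, not epochs; the
+    caller passes ``step_each_epoch * args.decay_epochs`` (see
+    Optimizer.exponential_decay_warmup).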
+ """ + global_step = _decay_step_counter() + lr = fluid.layers.tensor.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=True, + name="learning_rate") + + warmup_epoch = fluid.layers.fill_constant( + shape=[1], dtype='float32', value=float(warm_up_epoch), force_cpu=True) + + with init_on_cpu(): + epoch = ops.floor(global_step / step_each_epoch) + with fluid.layers.control_flow.Switch() as switch: + with switch.case(epoch < warmup_epoch): + decayed_lr = learning_rate * (global_step / (step_each_epoch * warmup_epoch)) + fluid.layers.assign(input=decayed_lr, output=lr) + with switch.default(): + div_res = (global_step - warmup_epoch * step_each_epoch) / decay_epochs + div_res = ops.floor(div_res) + decayed_lr = learning_rate * (decay_rate ** div_res) + fluid.layers.assign(input=decayed_lr, output=lr) + + return lr def lr_warmup(learning_rate, warmup_steps, start_lr, end_lr): """ Applies linear learning rate warmup for distributed training @@ -123,8 +150,11 @@ class Optimizer(object): self.momentum_rate = args.momentum_rate self.step_epochs = args.step_epochs self.num_epochs = args.num_epochs - + self.warm_up_epochs = args.warm_up_epochs + self.decay_epochs = args.decay_epochs + self.decay_rate = args.decay_rate self.total_images = args.total_images + self.step = int(math.ceil(float(self.total_images) / self.batch_size)) def piecewise_decay(self): @@ -176,6 +206,28 @@ class Optimizer(object): regularization=fluid.regularizer.L2Decay(self.l2_decay)) return optimizer + def exponential_decay_warmup(self): + """exponential decay with warmup + + Returns: + a exponential_decay_with_warmup optimizer + """ + + learning_rate = exponential_decay_with_warmup( + learning_rate=self.lr, + step_each_epoch=self.step, + decay_epochs=self.step * self.decay_epochs, + decay_rate=self.decay_rate, + warm_up_epoch=self.warm_up_epochs) + optimizer = fluid.optimizer.RMSProp( + learning_rate=learning_rate, + regularization=fluid.regularizer.L2Decay(self.l2_decay), + momentum=self.momentum_rate, + rho=0.9, + epsilon=0.001 + ) + return optimizer + def linear_decay(self): """linear decay with Momentum optimizer diff --git a/PaddleCV/image_classification/utils/utility.py b/PaddleCV/image_classification/utils/utility.py index 1e3a8114..7f6b1704 100644 --- a/PaddleCV/image_classification/utils/utility.py +++ b/PaddleCV/image_classification/utils/utility.py @@ -29,7 +29,9 @@ import signal import paddle import paddle.fluid as fluid - +from paddle.fluid.wrapped_decorator import signature_safe_contextmanager +from paddle.fluid.framework import Program, program_guard, name_scope, default_main_program +from paddle.fluid import unique_name, layers def print_arguments(args): """Print argparse's arguments. 
@@ -103,7 +105,12 @@ def parse_args(): add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.") add_arg('l2_decay', float, 1e-4, "The l2_decay parameter.") add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.") + add_arg('warm_up_epochs', float, 5.0, "The value of warm up epochs") + add_arg('decay_epochs', float, 2.4, "Decay epochs of exponential decay learning rate scheduler") + add_arg('decay_rate', float, 0.97, "Decay rate of exponential decay learning rate scheduler") + add_arg('drop_connect_rate', float, 0.2, "The value of drop connect rate") parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") + # READER AND PREPROCESS add_arg('lower_scale', float, 0.08, "The value of lower_scale in ramdom_crop") add_arg('lower_ratio', float, 3./4., "The value of lower_ratio in ramdom_crop") @@ -115,6 +122,7 @@ def parse_args(): add_arg('reader_thread', int, 8, "The number of multi thread reader") add_arg('reader_buf_size', int, 2048, "The buf size of multi thread reader") add_arg('interpolation', int, None, "The interpolation mode") + add_arg('use_aa', bool, False, "Whether to use auto augment") parser.add_argument('--image_mean', nargs='+', type=float, default=[0.485, 0.456, 0.406], help="The mean of input image data") parser.add_argument('--image_std', nargs='+', type=float, default=[0.229, 0.224, 0.225], help="The std of input image data") @@ -127,6 +135,9 @@ def parse_args(): #NOTE: (2019/08/08) temporary disable use_distill #add_arg('use_distill', bool, False, "Whether to use distill") add_arg('random_seed', int, None, "random seed") + add_arg('use_ema', bool, False, "Whether to use ExponentialMovingAverage.") + add_arg('ema_decay', float, 0.9999, "The value of ema decay rate") + add_arg('padding_type', str, "SAME", "Padding type of convolution") # yapf: enable args = parser.parse_args() @@ -170,7 +181,7 @@ def check_args(args): # check learning rate strategy lr_strategy_list = [ - "piecewise_decay", "cosine_decay", "linear_decay", "cosine_decay_warmup" + "piecewise_decay", "cosine_decay", "linear_decay", "cosine_decay_warmup", "exponential_decay_warmup" ] if args.lr_strategy not in lr_strategy_list: warnings.warn( @@ -186,6 +197,11 @@ def check_args(args): 0, 1, 2, 3, 4 ], "Wrong interpolation, please set:\n0: cv2.INTER_NEAREST\n1: cv2.INTER_LINEAR\n2: cv2.INTER_CUBIC\n3: cv2.INTER_AREA\n4: cv2.INTER_LANCZOS4" + if args.padding_type: + assert args.padding_type in [ + "SAME", "VALID", "DYNAMIC" + ], "Wrong padding_type, please set:\nSAME\nVALID\nDYNAMIC" + assert args.checkpoint is None or args.pretrained_model is None, "Do not init model by checkpoint and pretrained_model both." 
# check pretrained_model path for loading @@ -381,3 +397,141 @@ def best_strategy_compiled(args, program, loss): exec_strategy=exec_strategy) return compiled_program + + +class ExponentialMovingAverage(object): + + def __init__(self, decay=0.999, thres_steps=None, zero_debias=False, name=None): + self._decay = decay + self._thres_steps = thres_steps + self._name = name if name is not None else '' + self._decay_var = self._get_ema_decay() + + self._params_tmps = [] + for param in default_main_program().global_block().all_parameters(): + if param.do_model_average != False: + tmp = param.block.create_var( + name=unique_name.generate(".".join( + [self._name + param.name, 'ema_tmp'])), + dtype=param.dtype, + persistable=False, + stop_gradient=True) + self._params_tmps.append((param, tmp)) + + self._ema_vars = {} + for param, tmp in self._params_tmps: + with param.block.program._optimized_guard( + [param, tmp]), name_scope('moving_average'): + self._ema_vars[param.name] = self._create_ema_vars(param) + + self.apply_program = Program() + block = self.apply_program.global_block() + with program_guard(main_program=self.apply_program): + decay_pow = self._get_decay_pow(block) + for param, tmp in self._params_tmps: + param = block._clone_variable(param) + tmp = block._clone_variable(tmp) + ema = block._clone_variable(self._ema_vars[param.name]) + layers.assign(input=param, output=tmp) + # bias correction + if zero_debias: + ema = ema / (1.0 - decay_pow) + layers.assign(input=ema, output=param) + + self.restore_program = Program() + block = self.restore_program.global_block() + with program_guard(main_program=self.restore_program): + for param, tmp in self._params_tmps: + tmp = block._clone_variable(tmp) + param = block._clone_variable(param) + layers.assign(input=tmp, output=param) + + def _get_ema_decay(self): + with default_main_program()._lr_schedule_guard(): + decay_var = layers.tensor.create_global_var( + shape=[1], + value=self._decay, + dtype='float32', + persistable=True, + name="scheduled_ema_decay_rate") + + if self._thres_steps is not None: + decay_t = (self._thres_steps + 1.0) / (self._thres_steps + 10.0) + with layers.control_flow.Switch() as switch: + with switch.case(decay_t < self._decay): + layers.tensor.assign(decay_t, decay_var) + with switch.default(): + layers.tensor.assign( + np.array( + [self._decay], dtype=np.float32), + decay_var) + return decay_var + + def _get_decay_pow(self, block): + global_steps = layers.learning_rate_scheduler._decay_step_counter() + decay_var = block._clone_variable(self._decay_var) + decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1) + return decay_pow_acc + + def _create_ema_vars(self, param): + param_ema = layers.create_global_var( + name=unique_name.generate(self._name + param.name + '_ema'), + shape=param.shape, + value=0.0, + dtype=param.dtype, + persistable=True) + + return param_ema + + def update(self): + """ + Update Exponential Moving Average. Should only call this method in + train program. 
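+
+        For each tracked parameter this performs, in effect,
+
+            ema = decay * ema + (1 - decay) * param
+
+        where ``decay`` is the global decay variable created in
+        ``_get_ema_decay`` (thresholded by ``thres_steps`` when it is set).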
+ """ + param_master_emas = [] + for param, tmp in self._params_tmps: + with param.block.program._optimized_guard( + [param, tmp]), name_scope('moving_average'): + param_ema = self._ema_vars[param.name] + if param.name + '.master' in self._ema_vars: + master_ema = self._ema_vars[param.name + '.master'] + param_master_emas.append([param_ema, master_ema]) + else: + ema_t = param_ema * self._decay_var + param * ( + 1 - self._decay_var) + layers.assign(input=ema_t, output=param_ema) + + # for fp16 params + for param_ema, master_ema in param_master_emas: + default_main_program().global_block().append_op( + type="cast", + inputs={"X": master_ema}, + outputs={"Out": param_ema}, + attrs={ + "in_dtype": master_ema.dtype, + "out_dtype": param_ema.dtype + }) + + @signature_safe_contextmanager + def apply(self, executor, need_restore=True): + """ + Apply moving average to parameters for evaluation. + + Args: + executor (Executor): The Executor to execute applying. + need_restore (bool): Whether to restore parameters after applying. + """ + executor.run(self.apply_program) + try: + yield + finally: + if need_restore: + self.restore(executor) + + def restore(self, executor): + """Restore parameters. + + Args: + executor (Executor): The Executor to execute restoring. + """ + executor.run(self.restore_program) -- GitLab