Commit f6f1c3ad authored by LielinJiang, committed by ruri

Add efficientNet to image classification. (#3358)

Add EfficientNet
Parent a546b663
......@@ -115,6 +115,8 @@ bash run.sh train model_name
* **l2_decay**: the L2 decay value. Default: 1e-4.
* **momentum_rate**: the momentum rate. Default: 0.9.
* **step_epochs**: decay steps of piecewise decay. Default: [30,60,90].
* **decay_epochs**: the epoch interval of exponential decay. Default: 2.4.
* **decay_rate**: the decay rate of exponential decay. Default: 0.97.

Reader and preprocessing configuration:
......@@ -125,6 +127,7 @@ bash run.sh train model_name
* **crop_size**: the crop size. Default: 224.
* **use_mixup**: whether to use mixup preprocessing. Default: False.
* **mixup_alpha**: the alpha value for mixup. Default: 0.2.
* **use_aa**: whether to use auto augment preprocessing. Default: False.
* **reader_thread**: the number of threads of the multi-thread reader. Default: 8.
* **reader_buf_size**: the buf_size of the multi-thread reader. Default: 2048.
* **interpolation**: the interpolation method. Default: None.
......@@ -138,6 +141,9 @@ bash run.sh train model_name
* **use_label_smoothing**: whether to use label smoothing. Default: False.
* **label_smoothing_epsilon**: the epsilon of label smoothing. Default: 0.1.
* **random_seed**: random seed. Default: 1000.
* **padding_type**: the padding type of the convolutions in EfficientNet. Default: "SAME".
* **use_ema**: whether to use ExponentialMovingAverage when updating parameters. Default: False.
* **ema_decay**: the decay rate of ExponentialMovingAverage. Default: 0.9999.

**Data reader note:** The data reader is defined in ```reader.py``` and is based on cv2 by default. In the [training](#模型训练) stage, the default augmentations are random crop and horizontal flip, while the [evaluation](#模型评估) and [inference](#模型预测) stages use center crop by default. Currently supported augmentations are:
......@@ -147,6 +153,7 @@ bash run.sh train model_name
* center crop
* resize
* horizontal flip
* auto augment
### Finetuning
......@@ -170,6 +177,20 @@ python eval.py \
```
Note: add and adjust other parameters according to the specific model and task.
### Evaluating with Exponential Moving Average
Note: if you trained the model with exponential moving average (--use_ema=True) and want to evaluate the averaged model, first use ema_clean.py to rename the saved EMA parameters back to the original parameter names.
```
python ema_clean.py \
--ema_model_dir=your_ema_model_dir \
--cleaned_model_dir=your_cleaned_model_dir
python eval.py \
--model=model_name \
--pretrained_model=your_cleaned_model_dir
```
### Inference
Inference obtains a model's prediction scores or image features. You can download the [published models and their performance](#已发布模型及其性能), set ```path_to_pretrain_model``` to the model's path, and run the following command to get predictions:
......@@ -361,6 +382,17 @@ PaddlePaddle/Models ImageClassification supports custom datasets
|[ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x48d_wsl_pretrained.tar) | 85.37% | 97.69% | 161.722 | |
|[Fix_ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/Fix_ResNeXt101_32x48d_wsl_pretrained.tar) | 86.26% | 97.97% | 236.091 | |
### EfficientNet Series
|Model | Top-1 | Top-5 | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) |
|- |:-: |:-: |:-: |:-: |
|[EfficientNetB0](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB0_pretrained.tar) | 77.38% | 93.31% | 10.303 | 4.334 |
|[EfficientNetB1](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB1_pretrained.tar) | 79.15% | 94.41% | 15.626 | 6.502 |
|[EfficientNetB2](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB2_pretrained.tar) | 79.85% | 94.74% | 17.847 | 7.558 |
|[EfficientNetB3](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB3_pretrained.tar) | 81.15% | 95.41% | 25.993 | 10.937 |
|[EfficientNetB4](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB4_pretrained.tar) | 82.85% | 96.23% | 47.734 | 18.536 |
|[EfficientNetB5](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB5_pretrained.tar) | 83.62% | 96.72% | 88.578 | 32.102 |
|[EfficientNetB6](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB6_pretrained.tar) | 84.00% | 96.88% | 138.670 | 51.059 |
|[EfficientNetB7](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB7_pretrained.tar) | 84.30% | 96.89% | 234.364 | 82.107 |
## FAQ
......@@ -400,6 +432,7 @@ PaddlePaddle/Models ImageClassification supports custom datasets
- SqueezeNet: [SQUEEZENET: ALEXNET-LEVEL ACCURACY WITH 50X FEWER PARAMETERS AND <0.5MB MODEL SIZE](https://arxiv.org/abs/1602.07360), Forrest N. Iandola, Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J. Dally, Kurt Keutzer
- ResNeXt101_wsl: [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932), Dhruv Mahajan, Ross Girshick, Vignesh Ramanathan, Kaiming He, Manohar Paluri, Yixuan Li, Ashwin Bharambe, Laurens van der Maaten
- Fix_ResNeXt101_wsl: [Fixing the train-test resolution discrepancy](https://arxiv.org/abs/1906.06423), Hugo Touvron, Andrea Vedaldi, Matthijs Douze, Hervé Jégou
- EfficientNet: [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946), Mingxing Tan, Quoc V. Le
## Update
- 2018/12/03 **Stage1**: Update AlexNet, ResNet50, ResNet101, MobileNetV1
......@@ -412,6 +445,7 @@ PaddlePaddle/Models ImageClassification supports custom datasets
- 2019/07/19 **Stage6**: Update ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2_x1_0, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, MobileNetV2_x0_25, MobileNetV2_x1_5, MobileNetV2_x2_0, ResNeXt50_vd_64x4d, ResNeXt101_32x4d, ResNeXt152_32x4d
- 2019/08/01 **Stage7**: Update DarkNet53, DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264, SqueezeNet1_0, SqueezeNet1_1, ResNeXt50_vd_32x4d, ResNeXt152_64x4d, ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl
- 2019/09/11 **Stage8**: Update ResNet18_vd, ResNet34_vd, MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV2_x0_75, MobileNetV3_small_x1_0, DPN68, DPN92, DPN98, DPN107, DPN131, ResNeXt101_vd_32x4d, ResNeXt152_vd_64x4d, Xception65, Xception71, Xception41_deeplab, Xception65_deeplab, SE_ResNet50_vd
- 2019/09/20 Update EfficientNet
## Contribute
......
......@@ -106,7 +106,9 @@ Solver and hyperparameters:
* **lr**: initialized learning rate. Default: 0.1.
* **l2_decay**: L2_decay parameter. Default: 1e-4.
* **momentum_rate**: momentum_rate. Default: 0.9.
* **step_epochs**: decay steps of piecewise decay. Default: [30,60,90].
* **decay_epochs**: the epoch interval of exponential decay. Default: 2.4.
* **decay_rate**: the decay rate of exponential decay. Default: 0.97.
Reader and preprocess:
......@@ -117,6 +119,7 @@ Reader and preprocess:
* **crop_size**: the crop size, Default: 224.
* **use_mixup**: whether to use mixup data processing or not. Default: False.
* **mixup_alpha**: the mixup_alpha parameter. Default: 0.2.
* **use_aa**: whether to use auto augment data processing or not. Default: False.
* **reader_thread**: the number of threads of the multi-thread reader. Default: 8.
* **reader_buf_size**: the buf_size of the multi-thread reader. Default: 2048.
* **interpolation**: the interpolation method. Default: None.
......@@ -129,7 +132,10 @@ Switch:
* **use_gpu**: whether to use GPU or not. Default: True.
* **use_label_smoothing**: whether to use label smoothing or not. Default: False.
* **label_smoothing_epsilon**: the label_smoothing_epsilon. Default: 0.1.
* **random_seed**: random seed for debugging. Default: 1000.
* **padding_type**: the padding type of the convolutions in EfficientNet. Default: "SAME".
* **use_ema**: whether to use ExponentialMovingAverage or not. Default: False.
* **ema_decay**: the value of the ExponentialMovingAverage decay rate (see the sketch below). Default: 0.9999.
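With **use_ema** enabled, a shadow copy of each parameter tracks the training weights and can be swapped in for evaluation. Here is a minimal NumPy sketch of the underlying update rule (illustrative only, not the Paddle API; the actual implementation is the `ExponentialMovingAverage` class in ```utils/utility.py```):

```
import numpy as np

def ema_update(shadow, param, decay=0.9999):
    """One EMA step: shadow = decay * shadow + (1 - decay) * param."""
    return decay * shadow + (1 - decay) * param

shadow = np.zeros(3)
for _ in range(100):
    param = np.random.rand(3)            # stand-in for freshly updated weights
    shadow = ema_update(shadow, param)   # shadow weights lag behind smoothly
```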
**data reader introduction:** Data reader is defined in ```reader.py```, default reader is implemented by opencv. In the [Training](#training) Stage, random crop and flipping are applied, while center crop is applied in the [Evaluation](#evaluation) and [Inference](#inference) stages. Supported data augmentation includes:
......@@ -139,6 +145,7 @@ Switch:
* center crop
* resize
* flipping
* auto augment
### Finetuning
......@@ -164,6 +171,20 @@ python eval.py \
Note: Add and adjust other parameters according to specific models and tasks.
### ExponentialMovingAverage Evaluation
Note: if you trained the model with the use_ema flag and want to evaluate the ExponentialMovingAverage model, you must first clean the saved model with ema_clean.py, which renames the saved EMA parameters back to the original parameter names.
```
python ema_clean.py \
--ema_model_dir=your_ema_model_dir \
--cleaned_model_dir=your_cleaned_model_dir
python eval.py \
--model=model_name \
--pretrained_model=your_cleaned_model_dir
```
### Inference
**Parameters unique to the Inference stage**
......@@ -343,6 +364,18 @@ Pretrained models can be downloaded by clicking related model names.
|[ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_32x48d_wsl_pretrained.tar) | 85.37% | 97.69% | 161.722 | |
|[Fix_ResNeXt101_32x48d_wsl](https://paddle-imagenet-models-name.bj.bcebos.com/Fix_ResNeXt101_32x48d_wsl_pretrained.tar) | 86.26% | 97.97% | 236.091 | |
### EfficientNet Series
|Model | Top-1 | Top-5 | Paddle Fluid inference time(ms) | Paddle TensorRT inference time(ms) |
|- |:-: |:-: |:-: |:-: |
|[EfficientNetB0](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB0_pretrained.tar) | 77.38% | 93.31% | 10.303 | 4.334 |
|[EfficientNetB1](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB1_pretrained.tar) | 79.15% | 94.41% | 15.626 | 6.502 |
|[EfficientNetB2](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB2_pretrained.tar) | 79.85% | 94.74% | 17.847 | 7.558 |
|[EfficientNetB3](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB3_pretrained.tar) | 81.15% | 95.41% | 25.993 | 10.937 |
|[EfficientNetB4](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB4_pretrained.tar) | 82.85% | 96.23% | 47.734 | 18.536 |
|[EfficientNetB5](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB5_pretrained.tar) | 83.62% | 96.72% | 88.578 | 32.102 |
|[EfficientNetB6](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB6_pretrained.tar) | 84.00% | 96.88% | 138.670 | 51.059 |
|[EfficientNetB7](https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB7_pretrained.tar) | 84.30% | 96.89% | 234.364 | 82.107 |
## FAQ
**Q:** How do I solve this problem when I train on a 6-class dataset with the pretrained_model parameter specified?
......@@ -374,6 +407,7 @@ Enforce failed. Expected x_dims[1] == labels_dims[1], but received x_dims[1]:100
- SqueezeNet: [SQUEEZENET: ALEXNET-LEVEL ACCURACY WITH 50X FEWER PARAMETERS AND <0.5MB MODEL SIZE](https://arxiv.org/abs/1602.07360), Forrest N. Iandola, Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J. Dally, Kurt Keutzer
- ResNeXt101_wsl: [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932), Dhruv Mahajan, Ross Girshick, Vignesh Ramanathan, Kaiming He, Manohar Paluri, Yixuan Li, Ashwin Bharambe, Laurens van der Maaten
- Fix_ResNeXt101_wsl: [Fixing the train-test resolution discrepancy](https://arxiv.org/abs/1906.06423), Hugo Touvron, Andrea Vedaldi, Matthijs Douze, Hervé Jégou
- EfficientNet: [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946), Mingxing Tan, Quoc V. Le
## Update
......@@ -387,6 +421,7 @@ Enforce failed. Expected x_dims[1] == labels_dims[1], but received x_dims[1]:100
- 2019/07/19 **Stage6**: Update ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2_x1_0, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, MobileNetV2_x0_25, MobileNetV2_x1_5, MobileNetV2_x2_0, ResNeXt50_vd_64x4d, ResNeXt101_32x4d, ResNeXt152_32x4d
- 2019/08/01 **Stage7**: Update DarkNet53, DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264, SqueezeNet1_0, SqueezeNet1_1, ResNeXt50_vd_32x4d, ResNeXt152_64x4d, ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl
- 2019/09/11 **Stage8**: Update ResNet18_vd, ResNet34_vd, MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV2_x0_75, MobileNetV3_small_x1_0, DPN68, DPN92, DPN98, DPN107, DPN131, ResNeXt101_vd_32x4d, ResNeXt152_vd_64x4d, Xception65, Xception71, Xception41_deeplab, Xception65_deeplab, SE_ResNet50_vd
- 2019/09/20 Update EfficientNet
## Contribute
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import functools
import shutil
from utils import print_arguments, add_arguments
parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('ema_model_dir', str, None, "The directory of the model trained with ExponentialMovingAverage")
add_arg('cleaned_model_dir', str, None, "The directory for the cleaned model")
# yapf: enable
def main():
    args = parser.parse_args()
    print_arguments(args)

    if not os.path.exists(args.cleaned_model_dir):
        os.makedirs(args.cleaned_model_dir)

    # Copy every EMA parameter file, stripping the '_ema_0' suffix so the
    # names match the original parameter names expected by eval.py.
    items = os.listdir(args.ema_model_dir)
    for item in items:
        if item.find('ema') > -1:
            item_clean = item.replace('_ema_0', '')
            shutil.copyfile(os.path.join(args.ema_model_dir, item),
                            os.path.join(args.cleaned_model_dir, item_clean))

if __name__ == '__main__':
    main()
......@@ -46,6 +46,8 @@ add_arg('reader_buf_size', int, 2048, "The buf size of multi t
parser.add_argument('--image_mean', nargs='+', type=float, default=[0.485, 0.456, 0.406], help="The mean of input image data")
parser.add_argument('--image_std', nargs='+', type=float, default=[0.229, 0.224, 0.225], help="The std of input image data")
add_arg('crop_size', int, 224, "The value of crop size")
add_arg('interpolation', int, None, "The interpolation mode")
add_arg('padding_type', str, "SAME", "Padding type of convolution")
# yapf: enable
......@@ -64,7 +66,10 @@ def eval(args):
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# model definition
if args.model.startswith('EfficientNet'):
model = models.__dict__[args.model](is_test=True, padding_type=args.padding_type)
else:
model = models.__dict__[args.model]()
if args.model == "GoogLeNet":
out0, out1, out2 = model.net(input=image, class_dim=args.class_dim)
......
......@@ -47,7 +47,8 @@ parser.add_argument('--image_mean', nargs='+', type=float, default=[0.485, 0.456
parser.add_argument('--image_std', nargs='+', type=float, default=[0.229, 0.224, 0.225], help="The std of input image data")
add_arg('crop_size', int, 224, "The value of crop size")
add_arg('topk', int, 1, "topk")
add_arg('label_path', str, "./utils/tools/readable_label.txt", "readable label filepath")
add_arg('interpolation', int, None, "The interpolation mode")
# yapf: enable
......
......@@ -37,3 +37,4 @@ from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseN
from .squeezenet import SqueezeNet1_0, SqueezeNet1_1
from .darknet import DarkNet53
from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl
from .efficientnet import EfficientNet, EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
import collections
import re
import math
import copy
from .layers import conv2d, init_batch_norm_layer, init_fc_layer
__all__ = ['EfficientNet', 'EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', 'EfficientNetB4',
'EfficientNetB5', 'EfficientNetB6', 'EfficientNetB7']
GlobalParams = collections.namedtuple('GlobalParams', [
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
'num_classes', 'width_coefficient', 'depth_coefficient',
'depth_divisor', 'min_depth', 'drop_connect_rate', ])
BlockArgs = collections.namedtuple('BlockArgs', [
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
def efficientnet_params(model_name):
""" Map EfficientNet model name to parameter coefficients. """
params_dict = {
# Coefficients: width,depth,resolution,dropout
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
}
return params_dict[model_name]
def efficientnet(width_coefficient=None, depth_coefficient=None,
dropout_rate=0.2, drop_connect_rate=0.2):
""" Get block arguments according to parameter and coefficients. """
blocks_args = [
'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
'r1_k3_s11_e6_i192_o320_se0.25',
]
blocks_args = BlockDecoder.decode(blocks_args)
global_params = GlobalParams(
batch_norm_momentum=0.99,
batch_norm_epsilon=1e-3,
dropout_rate=dropout_rate,
drop_connect_rate=drop_connect_rate,
num_classes=1000,
width_coefficient=width_coefficient,
depth_coefficient=depth_coefficient,
depth_divisor=8,
min_depth=None
)
return blocks_args, global_params
def get_model_params(model_name, override_params):
""" Get the block args and global params for a given model """
if model_name.startswith('efficientnet'):
w, d, _, p = efficientnet_params(model_name)
blocks_args, global_params = efficientnet(width_coefficient=w, depth_coefficient=d, dropout_rate=p)
else:
raise NotImplementedError('model name is not pre-defined: %s' % model_name)
if override_params:
global_params = global_params._replace(**override_params)
return blocks_args, global_params
def round_filters(filters, global_params):
""" Calculate and round number of filters based on depth multiplier. """
multiplier = global_params.width_coefficient
if not multiplier:
return filters
divisor = global_params.depth_divisor
min_depth = global_params.min_depth
filters *= multiplier
min_depth = min_depth or divisor
new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
new_filters += divisor
return int(new_filters)
def round_repeats(repeats, global_params):
""" Round number of filters based on depth multiplier. """
multiplier = global_params.depth_coefficient
if not multiplier:
return repeats
return int(math.ceil(multiplier * repeats))
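# Illustrative examples (added; coefficients taken from efficientnet_params
# above, here EfficientNet-B6 with width=1.8, depth=2.6, depth_divisor=8):
#   round_filters(32, gp)  -> 32 * 1.8 = 57.6, snapped to the nearest
#                             multiple of 8 within the 10% tolerance -> 56
#   round_repeats(3, gp)   -> ceil(3 * 2.6) = 8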
class EfficientNet():
def __init__(self, name='b0', padding_type='SAME', override_params=None, is_test=False):
valid_names = ['b' + str(i) for i in range(8)]
        assert name in valid_names, 'EfficientNet name should be in b0~b7'
model_name = 'efficientnet-' + name
self._blocks_args, self._global_params = get_model_params(model_name, override_params)
self._bn_mom = self._global_params.batch_norm_momentum
self._bn_eps = self._global_params.batch_norm_epsilon
self.is_test = is_test
self.padding_type = padding_type
def net(self, input, class_dim=1000, is_test=False):
conv = self.extract_features(input, is_test=is_test)
out_channels = round_filters(1280, self._global_params)
conv = self.conv_bn_layer(conv,
num_filters=out_channels,
filter_size=1,
bn_act='swish',
bn_mom=self._bn_mom,
bn_eps=self._bn_eps,
padding_type=self.padding_type,
name='',
conv_name='_conv_head',
bn_name='_bn1')
pool = fluid.layers.pool2d(input=conv, pool_type='avg', global_pooling=True, use_cudnn=False)
if self._global_params.dropout_rate:
pool = fluid.layers.dropout(pool, self._global_params.dropout_rate, dropout_implementation='upscale_in_train')
param_attr, bias_attr = init_fc_layer(class_dim, '_fc')
out = fluid.layers.fc(pool, class_dim, name='_fc', param_attr=param_attr, bias_attr=bias_attr)
return out
def _drop_connect(self, inputs, prob, is_test):
if is_test:
return inputs
keep_prob = 1.0 - prob
random_tensor = keep_prob + fluid.layers.uniform_random_batch_size_like(inputs, [-1, 1, 1, 1], min=0., max=1.)
binary_tensor = fluid.layers.floor(random_tensor)
output = inputs / keep_prob * binary_tensor
return output
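    # Added note: _drop_connect implements per-example stochastic depth.
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, so each
    # example survives with probability keep_prob and is rescaled by
    # 1 / keep_prob, keeping the expected activation unchanged; at test
    # time the op is the identity.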
def _expand_conv_norm(self, inputs, block_args, is_test, name=None):
# Expansion phase
oup = block_args.input_filters * block_args.expand_ratio # number of output channels
if block_args.expand_ratio != 1:
conv = self.conv_bn_layer(inputs,
num_filters=oup,
filter_size=1,
bn_act=None,
bn_mom=self._bn_mom,
bn_eps=self._bn_eps,
padding_type=self.padding_type,
name=name,
conv_name=name + '_expand_conv',
bn_name='_bn0')
return conv
def _depthwise_conv_norm(self, inputs, block_args, is_test, name=None):
k = block_args.kernel_size
s = block_args.stride
if isinstance(s, list) or isinstance(s, tuple):
s = s[0]
oup = block_args.input_filters * block_args.expand_ratio # number of output channels
conv = self.conv_bn_layer(inputs,
num_filters=oup,
filter_size=k,
stride=s,
num_groups=oup,
bn_act=None,
padding_type=self.padding_type,
bn_mom=self._bn_mom,
bn_eps=self._bn_eps,
name=name,
use_cudnn=False,
conv_name=name + '_depthwise_conv',
bn_name='_bn1')
return conv
def _project_conv_norm(self, inputs, block_args, is_test, name=None):
final_oup = block_args.output_filters
conv = self.conv_bn_layer(inputs,
num_filters=final_oup,
filter_size=1,
bn_act=None,
padding_type=self.padding_type,
bn_mom=self._bn_mom,
bn_eps=self._bn_eps,
name=name,
conv_name=name + '_project_conv',
bn_name='_bn2')
return conv
def conv_bn_layer(self, input, filter_size, num_filters, stride=1, num_groups=1, padding_type="SAME", conv_act=None,
bn_act='swish', use_cudnn=True, use_bn=True, bn_mom=0.9, bn_eps=1e-05, use_bias=False, name=None,
conv_name=None, bn_name=None):
conv = conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
groups=num_groups,
act=conv_act,
padding_type=padding_type,
use_cudnn=use_cudnn,
name=conv_name,
use_bias=use_bias)
if use_bn == False:
return conv
else:
bn_name = name + bn_name
param_attr, bias_attr = init_batch_norm_layer(bn_name)
return fluid.layers.batch_norm(input=conv,
act=bn_act,
momentum=bn_mom,
epsilon=bn_eps,
name=bn_name,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
param_attr=param_attr,
bias_attr=bias_attr)
def _conv_stem_norm(self, inputs, is_test):
out_channels = round_filters(32, self._global_params)
bn = self.conv_bn_layer(inputs, num_filters=out_channels, filter_size=3, stride=2, bn_act=None,
bn_mom=self._bn_mom, padding_type=self.padding_type,
bn_eps=self._bn_eps, name='', conv_name='_conv_stem', bn_name='_bn0')
return bn
def mb_conv_block(self, inputs, block_args, is_test=False, drop_connect_rate=None, name=None):
# Expansion and Depthwise Convolution
oup = block_args.input_filters * block_args.expand_ratio # number of output channels
has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1)
id_skip = block_args.id_skip # skip connection and drop connect
conv = inputs
if block_args.expand_ratio != 1:
conv = fluid.layers.swish(self._expand_conv_norm(conv, block_args, is_test, name))
conv = fluid.layers.swish(self._depthwise_conv_norm(conv, block_args, is_test, name))
# Squeeze and Excitation
if has_se:
num_squeezed_channels = max(1, int(block_args.input_filters * block_args.se_ratio))
conv = self.se_block(conv, num_squeezed_channels, oup, name)
conv = self._project_conv_norm(conv, block_args, is_test, name)
# Skip connection and drop connect
input_filters, output_filters = block_args.input_filters, block_args.output_filters
if id_skip and block_args.stride == 1 and input_filters == output_filters:
if drop_connect_rate:
conv = self._drop_connect(conv, drop_connect_rate, self.is_test)
conv = fluid.layers.elementwise_add(conv, inputs)
return conv
def se_block(self, inputs, num_squeezed_channels, oup, name):
x_squeezed = fluid.layers.pool2d(
input=inputs,
pool_type='avg',
global_pooling=True,
use_cudnn=False)
x_squeezed = conv2d(x_squeezed,
num_filters=num_squeezed_channels,
filter_size=1,
use_bias=True,
padding_type=self.padding_type,
act='swish',
name=name + '_se_reduce')
x_squeezed = conv2d(x_squeezed,
num_filters=oup,
filter_size=1,
use_bias=True,
padding_type=self.padding_type,
name=name + '_se_expand')
se_out = inputs * fluid.layers.sigmoid(x_squeezed)
return se_out
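    # Added note: se_block is a standard squeeze-and-excitation gate. Global
    # average pooling squeezes each channel to 1x1, a 1x1 conv reduces to
    # num_squeezed_channels with swish, a second 1x1 conv expands back to
    # oup channels, and the sigmoid output rescales the input channel-wise.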
def extract_features(self, inputs, is_test):
""" Returns output of the final convolution layer """
conv = fluid.layers.swish(self._conv_stem_norm(inputs, is_test=is_test))
block_args_copy = copy.deepcopy(self._blocks_args)
idx = 0
block_size = 0
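        # Added note: this first pass only counts the total number of blocks
        # after depth scaling, so that drop_connect_rate can be scaled
        # linearly with the block index (idx / block_size) in the loop below.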
for block_arg in block_args_copy:
block_arg = block_arg._replace(
input_filters=round_filters(block_arg.input_filters, self._global_params),
output_filters=round_filters(block_arg.output_filters, self._global_params),
num_repeat=round_repeats(block_arg.num_repeat, self._global_params)
)
block_size += 1
for _ in range(block_arg.num_repeat - 1):
block_size += 1
for block_args in self._blocks_args:
# Update block input and output filters based on depth multiplier.
block_args = block_args._replace(
input_filters=round_filters(block_args.input_filters, self._global_params),
output_filters=round_filters(block_args.output_filters, self._global_params),
num_repeat=round_repeats(block_args.num_repeat, self._global_params)
)
# The first block needs to take care of stride and filter size increase.
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / block_size
conv = self.mb_conv_block(conv, block_args, is_test, drop_connect_rate, '_blocks.' + str(idx) + '.')
idx += 1
if block_args.num_repeat > 1:
block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
for _ in range(block_args.num_repeat - 1):
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / block_size
conv = self.mb_conv_block(conv, block_args, is_test, drop_connect_rate, '_blocks.' + str(idx) + '.')
idx += 1
return conv
def shortcut(self, input, data_residual):
return fluid.layers.elementwise_add(input, data_residual)
class BlockDecoder(object):
""" Block Decoder for readability, straight from the official TensorFlow repository """
@staticmethod
def _decode_block_string(block_string):
""" Gets a block through a string notation of arguments. """
assert isinstance(block_string, str)
ops = block_string.split('_')
options = {}
for op in ops:
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
# Check stride
assert (('s' in options and len(options['s']) == 1) or
(len(options['s']) == 2 and options['s'][0] == options['s'][1]))
return BlockArgs(
kernel_size=int(options['k']),
num_repeat=int(options['r']),
input_filters=int(options['i']),
output_filters=int(options['o']),
expand_ratio=int(options['e']),
id_skip=('noskip' not in block_string),
se_ratio=float(options['se']) if 'se' in options else None,
stride=[int(options['s'][0])])
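    # Example (added): _decode_block_string('r2_k5_s22_e6_i24_o40_se0.25')
    # yields BlockArgs(kernel_size=5, num_repeat=2, input_filters=24,
    # output_filters=40, expand_ratio=6, id_skip=True, stride=[2],
    # se_ratio=0.25).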
@staticmethod
def _encode_block_string(block):
"""Encodes a block to a string."""
args = [
'r%d' % block.num_repeat,
'k%d' % block.kernel_size,
        's%d%d' % (block.stride[0], block.stride[0]),  # BlockArgs stores a one-element stride list
'e%s' % block.expand_ratio,
'i%d' % block.input_filters,
'o%d' % block.output_filters
]
if 0 < block.se_ratio <= 1:
args.append('se%s' % block.se_ratio)
if block.id_skip is False:
args.append('noskip')
return '_'.join(args)
@staticmethod
def decode(string_list):
"""
Decodes a list of string notations to specify blocks inside the network.
:param string_list: a list of strings, each string is a notation of block
:return: a list of BlockArgs namedtuples of block args
"""
assert isinstance(string_list, list)
blocks_args = []
for block_string in string_list:
blocks_args.append(BlockDecoder._decode_block_string(block_string))
return blocks_args
@staticmethod
def encode(blocks_args):
"""
Encodes a list of BlockArgs to a list of strings.
:param blocks_args: a list of BlockArgs namedtuples of block args
:return: a list of strings, each string is a notation of block
"""
block_strings = []
for block in blocks_args:
block_strings.append(BlockDecoder._encode_block_string(block))
return block_strings
def EfficientNetB0(is_test=False, padding_type='SAME', override_params=None):
model = EfficientNet(name='b0', is_test=is_test, padding_type=padding_type, override_params=override_params)
return model
def EfficientNetB1(is_test=False, padding_type='SAME', override_params=None):
model = EfficientNet(name='b1', is_test=is_test, padding_type=padding_type, override_params=override_params)
return model
def EfficientNetB2(is_test=False, padding_type='SAME', override_params=None):
model = EfficientNet(name='b2', is_test=is_test, padding_type=padding_type, override_params=override_params)
return model
def EfficientNetB3(is_test=False, padding_type='SAME', override_params=None):
model = EfficientNet(name='b3', is_test=is_test, padding_type=padding_type, override_params=override_params)
return model
def EfficientNetB4(is_test=False, padding_type='SAME', override_params=None):
model = EfficientNet(name='b4', is_test=is_test, padding_type=padding_type, override_params=override_params)
return model
def EfficientNetB5(is_test=False, padding_type='SAME', override_params=None):
model = EfficientNet(name='b5', is_test=is_test, padding_type=padding_type, override_params=override_params)
return model
def EfficientNetB6(is_test=False, padding_type='SAME', override_params=None):
model = EfficientNet(name='b6', is_test=is_test, padding_type=padding_type, override_params=override_params)
return model
def EfficientNetB7(is_test=False, padding_type='SAME', override_params=None):
model = EfficientNet(name='b7', is_test=is_test, padding_type=padding_type, override_params=override_params)
return model
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
import math
import warnings
def initial_type(name,
input,
op_type,
fan_out,
init="google",
use_bias=False,
filter_size=0,
stddev=0.02):
if init == "kaiming":
if op_type == 'conv':
fan_in = input.shape[1] * filter_size * filter_size
elif op_type == 'deconv':
fan_in = fan_out * filter_size * filter_size
else:
if len(input.shape) > 2:
fan_in = input.shape[1] * input.shape[2] * input.shape[3]
else:
fan_in = input.shape[1]
bound = 1 / math.sqrt(fan_in)
param_attr = fluid.ParamAttr(
name=name + "_weights",
initializer=fluid.initializer.Uniform(
low=-bound, high=bound))
if use_bias == True:
bias_attr = fluid.ParamAttr(
name=name + '_offset',
initializer=fluid.initializer.Uniform(
low=-bound, high=bound))
else:
bias_attr = False
elif init == 'google':
n = filter_size * filter_size * fan_out
param_attr = fluid.ParamAttr(
name=name + "_weights",
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=math.sqrt(2.0 / n)))
if use_bias == True:
bias_attr = fluid.ParamAttr(
name=name + "_offset", initializer=fluid.initializer.Constant(0.0))
else:
bias_attr = False
else:
param_attr = fluid.ParamAttr(
name=name + "_weights",
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=stddev))
if use_bias == True:
bias_attr = fluid.ParamAttr(
name=name + "_offset", initializer=fluid.initializer.Constant(0.0))
else:
bias_attr = False
return param_attr, bias_attr
def cal_padding(img_size, stride, filter_size, dilation=1):
"""Calculate padding size."""
if img_size % stride == 0:
out_size = max(filter_size - stride, 0)
else:
out_size = max(filter_size - (img_size % stride), 0)
return out_size // 2, out_size - out_size // 2
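# Worked example (added): cal_padding(img_size=224, stride=2, filter_size=3)
# gives out_size = max(3 - 2, 0) = 1 and returns (0, 1), i.e. asymmetric
# padding; conv2d below emulates this "SAME" behaviour by over-padding
# symmetrically and then cropping.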
def init_batch_norm_layer(name="batch_norm"):
param_attr = fluid.ParamAttr(
name=name + '_scale', initializer=fluid.initializer.Constant(1.0))
bias_attr = fluid.ParamAttr(
name=name + '_offset', initializer=fluid.initializer.Constant(value=0.0))
return param_attr, bias_attr
def init_fc_layer(fout, name='fc'):
n = fout # fan-out
init_range = 1.0 / math.sqrt(n)
param_attr = fluid.ParamAttr(
name=name + '_weights', initializer=fluid.initializer.UniformInitializer(
low=-init_range, high=init_range))
bias_attr = fluid.ParamAttr(
name=name + '_offset', initializer=fluid.initializer.Constant(value=0.0))
return param_attr, bias_attr
def norm_layer(input, norm_type='batch_norm', name=None):
if norm_type == 'batch_norm':
param_attr = fluid.ParamAttr(
name=name + '_weights', initializer=fluid.initializer.Constant(1.0))
bias_attr = fluid.ParamAttr(
name=name + '_offset', initializer=fluid.initializer.Constant(value=0.0))
return fluid.layers.batch_norm(
input,
param_attr=param_attr,
bias_attr=bias_attr,
moving_mean_name=name + '_mean',
moving_variance_name=name + '_variance')
elif norm_type == 'instance_norm':
helper = fluid.layer_helper.LayerHelper("instance_norm", **locals())
dtype = helper.input_dtype()
epsilon = 1e-5
mean = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True)
var = fluid.layers.reduce_mean(
fluid.layers.square(input - mean), dim=[2, 3], keep_dim=True)
if name is not None:
scale_name = name + "_scale"
offset_name = name + "_offset"
scale_param = fluid.ParamAttr(
name=scale_name,
initializer=fluid.initializer.Constant(1.0),
trainable=True)
offset_param = fluid.ParamAttr(
name=offset_name,
initializer=fluid.initializer.Constant(0.0),
trainable=True)
scale = helper.create_parameter(
attr=scale_param, shape=input.shape[1:2], dtype=dtype)
offset = helper.create_parameter(
attr=offset_param, shape=input.shape[1:2], dtype=dtype)
tmp = fluid.layers.elementwise_mul(x=(input - mean), y=scale, axis=1)
tmp = tmp / fluid.layers.sqrt(var + epsilon)
tmp = fluid.layers.elementwise_add(tmp, offset, axis=1)
return tmp
else:
raise NotImplementedError("norm tyoe: [%s] is not support" % norm_type)
def conv2d(input,
num_filters=64,
filter_size=7,
stride=1,
stddev=0.02,
padding=0,
groups=None,
name="conv2d",
norm=None,
act=None,
relufactor=0.0,
use_bias=False,
padding_type=None,
initial="normal",
use_cudnn=True):
    if padding != 0 and padding_type is not None:
        warnings.warn(
            'padding value and padding_type are both set; the final padding '
            'height and width are determined by padding_type'
        )
param_attr, bias_attr = initial_type(
name=name,
input=input,
op_type='conv',
fan_out=num_filters,
init=initial,
use_bias=use_bias,
filter_size=filter_size,
stddev=stddev)
def get_padding(filter_size, stride=1, dilation=1):
padding = ((stride - 1) + dilation * (filter_size - 1)) // 2
return padding
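    # Added note: fluid.layers.conv2d only accepts symmetric padding, so when
    # "SAME" padding turns out asymmetric (top != bottom or left != right),
    # the code pads one extra stride on each axis and crops the first
    # row/column of the result afterwards (need_crop).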
need_crop = False
if padding_type == "SAME":
        top_padding, bottom_padding = cal_padding(input.shape[2], stride,
                                                  filter_size)
        # fix: use the width dimension (shape[3]) for the horizontal padding
        left_padding, right_padding = cal_padding(input.shape[3], stride,
                                                  filter_size)
height_padding = bottom_padding
width_padding = right_padding
if top_padding != bottom_padding or left_padding != right_padding:
height_padding = top_padding + stride
width_padding = left_padding + stride
need_crop = True
padding = [height_padding, width_padding]
elif padding_type == "VALID":
height_padding = 0
width_padding = 0
padding = [height_padding, width_padding]
elif padding_type == "DYNAMIC":
padding = get_padding(filter_size, stride)
else:
padding = padding
conv = fluid.layers.conv2d(
input,
num_filters,
filter_size,
groups=groups,
name=name,
stride=stride,
padding=padding,
use_cudnn=use_cudnn,
param_attr=param_attr,
bias_attr=bias_attr)
if need_crop:
conv = conv[:, :, 1:, 1:]
if norm is not None:
conv = norm_layer(input=conv, norm_type=norm, name=name + "_norm")
if act == 'relu':
conv = fluid.layers.relu(conv, name=name + '_relu')
elif act == 'leaky_relu':
conv = fluid.layers.leaky_relu(
conv, alpha=relufactor, name=name + '_leaky_relu')
elif act == 'tanh':
conv = fluid.layers.tanh(conv, name=name + '_tanh')
elif act == 'sigmoid':
conv = fluid.layers.sigmoid(conv, name=name + '_sigmoid')
elif act == 'swish':
conv = fluid.layers.swish(conv, name=name + '_swish')
elif act == None:
conv = conv
else:
raise NotImplementedError("activation: [%s] is not support" %act)
return conv
......@@ -18,11 +18,12 @@ import random
import functools
import numpy as np
import cv2
import io
import signal
import paddle
import paddle.fluid as fluid
from utils.autoaugment import ImageNetPolicy
from PIL import Image
policy = None
random.seed(0)
np.random.seed(0)
......@@ -200,7 +201,6 @@ def create_mixup_reader(settings, rd):
return mixup_reader
def process_image(sample, settings, mode, color_jitter, rotate):
""" process_image """
......@@ -215,7 +215,7 @@ def process_image(sample, settings, mode, color_jitter, rotate):
if rotate:
img = rotate_image(img)
if crop_size > 0:
img = random_crop(img, crop_size, settings, interpolation=settings.interpolation)
if color_jitter:
img = distort_color(img)
if np.random.randint(0, 2) == 1:
......@@ -223,10 +223,18 @@ def process_image(sample, settings, mode, color_jitter, rotate):
else:
if crop_size > 0:
target_size = settings.resize_short_size
img = resize_short(img, target_size, interpolation=settings.interpolation)
img = crop_image(img, target_size=crop_size, center=True)
img = img[:, :, ::-1]
if 'use_aa' in settings and settings.use_aa and mode == 'train':
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
img = policy(img)
img = np.asarray(img)
img = img.astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
......@@ -294,6 +302,11 @@ def train(settings):
assert os.path.isfile(
file_list), "{} doesn't exist, please check data list path".format(
file_list)
if 'use_aa' in settings and settings.use_aa:
global policy
policy = ImageNetPolicy()
reader = _reader_creator(
settings,
file_list,
......@@ -317,6 +330,7 @@ def val(settings):
Returns:
eval reader
"""
file_list = os.path.join(settings.data_dir, 'val_list.txt')
assert os.path.isfile(
file_list), "{} doesn't exist, please check data list path".format(
......
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.96
python -u train.py \
--model=EfficientNet \
--batch_size=512 \
--test_batch_size=128 \
--total_images=1281167 \
--class_dim=1000 \
--image_shape=3,224,224 \
--resize_short_size=256 \
--model_save_dir=output/ \
--lr_strategy=exponential_decay_warmup \
--lr=0.032 \
--num_epochs=360 \
--l2_decay=1e-5 \
--use_label_smoothing=True \
--label_smoothing_epsilon=0.1 \
--use_ema=True \
--ema_decay=0.9999 \
--drop_connect_rate=0.1 \
--padding_type="SAME" \
--interpolation=2 \
--use_aa=True
......@@ -20,9 +20,6 @@ import os
import numpy as np
import time
import sys
import functools
import math
def set_paddle_flags(flags):
for key, value in flags.items():
......@@ -38,10 +35,6 @@ set_paddle_flags({
'FLAGS_fraction_of_gpu_memory_to_use': 0.98
})
import argparse
import functools
import subprocess
import paddle
import paddle.fluid as fluid
import reader
......@@ -63,7 +56,13 @@ def build_program(is_train, main_prog, startup_prog, args):
train mode: [Loss, global_lr, py_reader]
test mode: [Loss, py_reader]
"""
    if args.model.startswith('EfficientNet'):
        is_test = not is_train
override_params = {"drop_connect_rate": args.drop_connect_rate}
padding_type = args.padding_type
model = models.__dict__[args.model](is_test=is_test, override_params=override_params, padding_type=padding_type)
else:
model = models.__dict__[args.model]()
with fluid.program_guard(main_prog, startup_prog):
if args.random_seed:
main_prog.random_seed = args.random_seed
......@@ -79,9 +78,50 @@ def build_program(is_train, main_prog, startup_prog, args):
global_lr = optimizer._global_learning_rate()
global_lr.persistable = True
loss_out.append(global_lr)
if args.use_ema:
global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter()
ema = ExponentialMovingAverage(args.ema_decay, thres_steps=global_steps)
ema.update()
loss_out.append(ema)
loss_out.append(py_reader)
return loss_out
def validate(args, test_py_reader, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record):
test_batch_time_record = []
test_batch_metrics_record = []
test_batch_id = 0
test_py_reader.start()
try:
while True:
t1 = time.time()
test_batch_metrics = exe.run(program=test_prog,
fetch_list=test_fetch_list)
t2 = time.time()
test_batch_elapse = t2 - t1
test_batch_time_record.append(test_batch_elapse)
test_batch_metrics_avg = np.mean(
np.array(test_batch_metrics), axis=1)
test_batch_metrics_record.append(test_batch_metrics_avg)
print_info(pass_id, test_batch_id, args.print_step,
test_batch_metrics_avg, test_batch_elapse, "batch")
sys.stdout.flush()
test_batch_id += 1
except fluid.core.EOFException:
test_py_reader.reset()
#train_epoch_time_avg = np.mean(np.array(train_batch_time_record))
train_epoch_metrics_avg = np.mean(
np.array(train_batch_metrics_record), axis=0)
test_epoch_time_avg = np.mean(np.array(test_batch_time_record))
test_epoch_metrics_avg = np.mean(
np.array(test_batch_metrics_record), axis=0)
print_info(pass_id, 0, 0,
list(train_epoch_metrics_avg) + list(test_epoch_metrics_avg),
test_epoch_time_avg, "epoch")
def train(args):
"""Train model
......@@ -99,7 +139,11 @@ def train(args):
startup_prog=startup_prog,
args=args)
train_py_reader = train_out[-1]
if args.use_ema:
train_fetch_vars = train_out[:-2]
ema = train_out[-2]
else:
train_fetch_vars = train_out[:-1]
train_fetch_list = [var.name for var in train_fetch_vars]
......@@ -143,11 +187,8 @@ def train(args):
for pass_id in range(args.num_epochs):
train_batch_id = 0
train_batch_time_record = []
train_batch_metrics_record = []
train_py_reader.start()
......@@ -171,38 +212,13 @@ def train(args):
except fluid.core.EOFException:
train_py_reader.reset()
if args.use_ema:
print('ExponentialMovingAverage validate start...')
with ema.apply(exe):
validate(args, test_py_reader, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record)
print('ExponentialMovingAverage validate over!')
validate(args, test_py_reader, exe, test_prog, test_fetch_list, pass_id, train_batch_metrics_record)
#For now, save model per epoch.
if pass_id % args.save_step == 0:
save_model(args, exe, train_prog, pass_id)
......
......@@ -11,5 +11,5 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from .optimizer import cosine_decay, lr_warmup, cosine_decay_with_warmup, exponential_decay_with_warmup, Optimizer, create_optimizer
from .utility import add_arguments, print_arguments, parse_args, check_gpu, check_args, init_model, save_model, create_pyreader, print_info, best_strategy_compiled, ExponentialMovingAverage
"""
This code is based on https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
"""
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
class ImageNetPolicy(object):
""" Randomly choose one of the best 24 Sub-policies on ImageNet.
Example:
>>> policy = ImageNetPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
]
def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment ImageNet Policy"
class CIFAR10Policy(object):
""" Randomly choose one of the best 25 Sub-policies on CIFAR10.
Example:
>>> policy = CIFAR10Policy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> CIFAR10Policy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
]
def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment CIFAR10 Policy"
class SVHNPolicy(object):
""" Randomly choose one of the best 25 Sub-policies on SVHN.
Example:
>>> policy = SVHNPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> SVHNPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
]
def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment SVHN Policy"
class SubPolicy(object):
def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
ranges = {
"shearX": np.linspace(0, 0.3, 10),
"shearY": np.linspace(0, 0.3, 10),
"translateX": np.linspace(0, 150 / 331, 10),
"translateY": np.linspace(0, 150 / 331, 10),
"rotate": np.linspace(0, 30, 10),
"color": np.linspace(0.0, 0.9, 10),
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
"solarize": np.linspace(256, 0, 10),
"contrast": np.linspace(0.0, 0.9, 10),
"sharpness": np.linspace(0.0, 0.9, 10),
"brightness": np.linspace(0.0, 0.9, 10),
"autocontrast": [0] * 10,
"equalize": [0] * 10,
"invert": [0] * 10
}
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
func = {
"shearX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
# "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
self.p1 = p1
self.operation1 = func[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
self.p2 = p2
self.operation2 = func[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
return img
......@@ -67,6 +67,33 @@ def cosine_decay_with_warmup(learning_rate, step_each_epoch, epochs=120):
fluid.layers.tensor.assign(input=decayed_lr, output=lr)
return lr
def exponential_decay_with_warmup(learning_rate, step_each_epoch, decay_epochs, decay_rate=0.97, warm_up_epoch=5.0):
"""Applies exponential decay to the learning rate.
"""
global_step = _decay_step_counter()
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
warmup_epoch = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(warm_up_epoch), force_cpu=True)
with init_on_cpu():
epoch = ops.floor(global_step / step_each_epoch)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(epoch < warmup_epoch):
decayed_lr = learning_rate * (global_step / (step_each_epoch * warmup_epoch))
fluid.layers.assign(input=decayed_lr, output=lr)
with switch.default():
div_res = (global_step - warmup_epoch * step_each_epoch) / decay_epochs
div_res = ops.floor(div_res)
decayed_lr = learning_rate * (decay_rate ** div_res)
fluid.layers.assign(input=decayed_lr, output=lr)
return lr
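# Worked example (added): with learning_rate=0.032, warm_up_epoch=5 and
# decay_rate=0.97, the LR ramps linearly from 0 to 0.032 over the first five
# epochs, then follows lr = 0.032 * 0.97 ** floor(steps_after_warmup /
# decay_epochs). Note that the Optimizer class below already converts
# decay_epochs from epochs to steps (self.step * self.decay_epochs) before
# calling this function.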
def lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
""" Applies linear learning rate warmup for distributed training
......@@ -123,8 +150,11 @@ class Optimizer(object):
self.momentum_rate = args.momentum_rate
self.step_epochs = args.step_epochs
self.num_epochs = args.num_epochs
self.warm_up_epochs = args.warm_up_epochs
self.decay_epochs = args.decay_epochs
self.decay_rate = args.decay_rate
self.total_images = args.total_images
self.step = int(math.ceil(float(self.total_images) / self.batch_size))
def piecewise_decay(self):
......@@ -176,6 +206,28 @@ class Optimizer(object):
regularization=fluid.regularizer.L2Decay(self.l2_decay))
return optimizer
def exponential_decay_warmup(self):
"""exponential decay with warmup
Returns:
        an exponential_decay_with_warmup optimizer
"""
learning_rate = exponential_decay_with_warmup(
learning_rate=self.lr,
step_each_epoch=self.step,
decay_epochs=self.step * self.decay_epochs,
decay_rate=self.decay_rate,
warm_up_epoch=self.warm_up_epochs)
optimizer = fluid.optimizer.RMSProp(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(self.l2_decay),
momentum=self.momentum_rate,
rho=0.9,
epsilon=0.001
)
return optimizer
def linear_decay(self):
"""linear decay with Momentum optimizer
......
......@@ -29,7 +29,9 @@ import signal
import paddle
import paddle.fluid as fluid
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.framework import Program, program_guard, name_scope, default_main_program
from paddle.fluid import unique_name, layers
def print_arguments(args):
"""Print argparse's arguments.
......@@ -103,7 +105,12 @@ def parse_args():
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 1e-4, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('warm_up_epochs', float, 5.0, "The value of warm up epochs")
add_arg('decay_epochs', float, 2.4, "Decay epochs of exponential decay learning rate scheduler")
add_arg('decay_rate', float, 0.97, "Decay rate of exponential decay learning rate scheduler")
add_arg('drop_connect_rate', float, 0.2, "The value of drop connect rate")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
# READER AND PREPROCESS
add_arg('lower_scale', float, 0.08, "The value of lower_scale in random_crop")
add_arg('lower_ratio', float, 3./4., "The value of lower_ratio in random_crop")
......@@ -115,6 +122,7 @@ def parse_args():
add_arg('reader_thread', int, 8, "The number of multi thread reader")
add_arg('reader_buf_size', int, 2048, "The buf size of multi thread reader")
add_arg('interpolation', int, None, "The interpolation mode")
add_arg('use_aa', bool, False, "Whether to use auto augment")
parser.add_argument('--image_mean', nargs='+', type=float, default=[0.485, 0.456, 0.406], help="The mean of input image data")
parser.add_argument('--image_std', nargs='+', type=float, default=[0.229, 0.224, 0.225], help="The std of input image data")
......@@ -127,6 +135,9 @@ def parse_args():
#NOTE: (2019/08/08) temporary disable use_distill
#add_arg('use_distill', bool, False, "Whether to use distill")
add_arg('random_seed', int, None, "random seed")
add_arg('use_ema', bool, False, "Whether to use ExponentialMovingAverage.")
add_arg('ema_decay', float, 0.9999, "The value of ema decay rate")
add_arg('padding_type', str, "SAME", "Padding type of convolution")
# yapf: enable
args = parser.parse_args()
......@@ -170,7 +181,7 @@ def check_args(args):
# check learning rate strategy
lr_strategy_list = [
"piecewise_decay", "cosine_decay", "linear_decay", "cosine_decay_warmup"
"piecewise_decay", "cosine_decay", "linear_decay", "cosine_decay_warmup", "exponential_decay_warmup"
]
if args.lr_strategy not in lr_strategy_list:
warnings.warn(
......@@ -186,6 +197,11 @@ def check_args(args):
0, 1, 2, 3, 4
], "Wrong interpolation, please set:\n0: cv2.INTER_NEAREST\n1: cv2.INTER_LINEAR\n2: cv2.INTER_CUBIC\n3: cv2.INTER_AREA\n4: cv2.INTER_LANCZOS4"
if args.padding_type:
assert args.padding_type in [
"SAME", "VALID", "DYNAMIC"
], "Wrong padding_type, please set:\nSAME\nVALID\nDYNAMIC"
assert args.checkpoint is None or args.pretrained_model is None, "Do not init model by checkpoint and pretrained_model both."
# check pretrained_model path for loading
......@@ -381,3 +397,141 @@ def best_strategy_compiled(args, program, loss):
exec_strategy=exec_strategy)
return compiled_program
class ExponentialMovingAverage(object):
def __init__(self, decay=0.999, thres_steps=None, zero_debias=False, name=None):
self._decay = decay
self._thres_steps = thres_steps
self._name = name if name is not None else ''
self._decay_var = self._get_ema_decay()
self._params_tmps = []
for param in default_main_program().global_block().all_parameters():
if param.do_model_average != False:
tmp = param.block.create_var(
name=unique_name.generate(".".join(
[self._name + param.name, 'ema_tmp'])),
dtype=param.dtype,
persistable=False,
stop_gradient=True)
self._params_tmps.append((param, tmp))
self._ema_vars = {}
for param, tmp in self._params_tmps:
with param.block.program._optimized_guard(
[param, tmp]), name_scope('moving_average'):
self._ema_vars[param.name] = self._create_ema_vars(param)
self.apply_program = Program()
block = self.apply_program.global_block()
with program_guard(main_program=self.apply_program):
decay_pow = self._get_decay_pow(block)
for param, tmp in self._params_tmps:
param = block._clone_variable(param)
tmp = block._clone_variable(tmp)
ema = block._clone_variable(self._ema_vars[param.name])
layers.assign(input=param, output=tmp)
# bias correction
if zero_debias:
ema = ema / (1.0 - decay_pow)
layers.assign(input=ema, output=param)
self.restore_program = Program()
block = self.restore_program.global_block()
with program_guard(main_program=self.restore_program):
for param, tmp in self._params_tmps:
tmp = block._clone_variable(tmp)
param = block._clone_variable(param)
layers.assign(input=tmp, output=param)
def _get_ema_decay(self):
with default_main_program()._lr_schedule_guard():
decay_var = layers.tensor.create_global_var(
shape=[1],
value=self._decay,
dtype='float32',
persistable=True,
name="scheduled_ema_decay_rate")
if self._thres_steps is not None:
decay_t = (self._thres_steps + 1.0) / (self._thres_steps + 10.0)
with layers.control_flow.Switch() as switch:
with switch.case(decay_t < self._decay):
layers.tensor.assign(decay_t, decay_var)
with switch.default():
layers.tensor.assign(
np.array(
[self._decay], dtype=np.float32),
decay_var)
return decay_var
def _get_decay_pow(self, block):
global_steps = layers.learning_rate_scheduler._decay_step_counter()
decay_var = block._clone_variable(self._decay_var)
decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1)
return decay_pow_acc
def _create_ema_vars(self, param):
param_ema = layers.create_global_var(
name=unique_name.generate(self._name + param.name + '_ema'),
shape=param.shape,
value=0.0,
dtype=param.dtype,
persistable=True)
return param_ema
def update(self):
"""
Update Exponential Moving Average. Should only call this method in
train program.
"""
param_master_emas = []
for param, tmp in self._params_tmps:
with param.block.program._optimized_guard(
[param, tmp]), name_scope('moving_average'):
param_ema = self._ema_vars[param.name]
if param.name + '.master' in self._ema_vars:
master_ema = self._ema_vars[param.name + '.master']
param_master_emas.append([param_ema, master_ema])
else:
ema_t = param_ema * self._decay_var + param * (
1 - self._decay_var)
layers.assign(input=ema_t, output=param_ema)
# for fp16 params
for param_ema, master_ema in param_master_emas:
default_main_program().global_block().append_op(
type="cast",
inputs={"X": master_ema},
outputs={"Out": param_ema},
attrs={
"in_dtype": master_ema.dtype,
"out_dtype": param_ema.dtype
})
@signature_safe_contextmanager
def apply(self, executor, need_restore=True):
"""
Apply moving average to parameters for evaluation.
Args:
executor (Executor): The Executor to execute applying.
need_restore (bool): Whether to restore parameters after applying.
"""
executor.run(self.apply_program)
try:
yield
finally:
if need_restore:
self.restore(executor)
def restore(self, executor):
"""Restore parameters.
Args:
executor (Executor): The Executor to execute restoring.
"""
executor.run(self.restore_program)
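# Usage sketch (added; mirrors how train.py uses this class):
#   global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter()
#   ema = ExponentialMovingAverage(args.ema_decay, thres_steps=global_steps)
#   ema.update()              # append the EMA update ops to the train program
#   ...
#   with ema.apply(exe):      # temporarily swap in the averaged weights
#       validate(...)         # evaluate the EMA model
#   # original weights are restored automatically on exit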