Commit ae89990b authored by chajchaj, committed by ruri

Add MobileNet v1 and v2 dygraph code (#4188)

Parent 8aeb4413
**Model Introduction**
Image classification is a core task in computer vision: its goal is to assign an image to one of a set of predefined labels. CNN models have achieved breakthrough results on image classification, but their complexity has grown steadily alongside. MobileNet is a compact yet efficient CNN model; this document describes how to use PaddlePaddle's dygraph (dynamic graph) implementation of MobileNet for image classification.
**Code Structure**
├── run_mul_v1.sh # multi-GPU training launch script for v1
├── run_mul_v2.sh # multi-GPU training launch script for v2
├── run_sing_v1.sh # single-GPU training launch script for v1
├── run_sing_v2.sh # single-GPU training launch script for v2
├── train.py # training entry point
├── mobilenet_v1.py # network definition, v1
├── mobilenet_v2.py # network definition, v2
├── reader.py # data reader
├── utils # utility code directory
**Data Preparation**
Please refer to: https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification
**Model Training**
To train with 4 GPUs, launch as follows:
bash run_mul_v1.sh
bash run_mul_v2.sh
To train with a single GPU, launch as follows:
bash run_sing_v1.sh
bash run_sing_v2.sh
**Model Accuracy**
| Model | Top-1 | Top-5 |
|-------------|-------|-------|
| MobileNetV1 | 0.707 | 0.895 |
| MobileNetV2 | 0.626 | 0.845 |
**References**
MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications, Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
MobileNetV2: Inverted Residuals and Linear Bottlenecks, Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
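As a quick sanity check outside the launch scripts, the sketch below builds MobileNetV1 and runs one forward pass on random input (illustrative only; it assumes a CUDA build of the PaddlePaddle 1.x fluid dygraph API used in this repo, with mobilenet_v1.py on the import path):

import numpy as np
import paddle.fluid as fluid
from mobilenet_v1 import MobileNetV1

with fluid.dygraph.guard(fluid.CUDAPlace(0)):
    net = MobileNetV1("mobilenet_v1", class_dim=1000)  # same call as train.py
    net.eval()
    x = fluid.dygraph.to_variable(
        np.random.rand(1, 3, 224, 224).astype('float32'))
    logits = net(x)
    print(logits.shape)  # expect [1, 1000]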
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import sys
import numpy as np
import argparse
import ast
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
import math
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
name_scope,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True,
name=None):
super(ConvBNLayer, self).__init__(name_scope)
self._conv = Conv2D(
self.full_name(),
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "_weights"),
bias_attr=False)
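        # bias_attr=False: the BatchNorm that follows supplies the learnable
        # offset (beta), so a separate convolution bias would be redundant.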
self._batch_norm = BatchNorm(
self.full_name(),
num_filters,
act=act,
param_attr=ParamAttr(name="_bn" + "_scale"),
bias_attr=ParamAttr(name="_bn" + "_offset"),
moving_mean_name="_bn" + '_mean',
moving_variance_name="_bn" + '_variance')
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class DepthwiseSeparable(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
name=None):
super(DepthwiseSeparable, self).__init__(name_scope)
self._depthwise_conv = ConvBNLayer(
name_scope="dw",
num_filters=int(num_filters1 * scale),
filter_size=3,
stride=stride,
padding=1,
num_groups=int(num_groups * scale),
use_cudnn=False)
self._pointwise_conv = ConvBNLayer(
name_scope="sep",
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
y = self._pointwise_conv(y)
return y
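# NOTE (editorial sketch): a depthwise-separable block factorizes one k x k
# convolution into a k x k depthwise conv plus a 1 x 1 pointwise conv,
# cutting parameters from c_in * c_out * k^2 down to c_in * k^2 + c_in * c_out.
# For the conv4_2 stage below (256 -> 512, k = 3) that is
# 256 * 512 * 9 = 1,179,648 versus 256 * 9 + 256 * 512 = 133,376 (~8.8x fewer).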
class MobileNetV1(fluid.dygraph.Layer):
def __init__(self, name_scope, scale=1.0, class_dim=102):
super(MobileNetV1, self).__init__(name_scope)
self.scale = scale
self.dwsl = []
self.conv1 = ConvBNLayer(
name_scope="conv1",
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
dws21 = self.add_sublayer(
sublayer=DepthwiseSeparable(
name_scope="conv2_1",
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale),
name="conv2_1")
self.dwsl.append(dws21)
dws22 = self.add_sublayer(
sublayer=DepthwiseSeparable(
name_scope="conv2_2",
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=2,
scale=scale),
name="conv2_2")
self.dwsl.append(dws22)
dws31 = self.add_sublayer(
sublayer=DepthwiseSeparable(
name_scope="conv3_1",
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale),
name="conv3_1")
self.dwsl.append(dws31)
dws32 = self.add_sublayer(
sublayer=DepthwiseSeparable(
name_scope="conv3_2",
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=2,
scale=scale),
name="conv3_2")
self.dwsl.append(dws32)
dws41 = self.add_sublayer(
sublayer=DepthwiseSeparable(
name_scope="conv4_1",
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale),
name="conv4_1")
self.dwsl.append(dws41)
dws42 = self.add_sublayer(
sublayer=DepthwiseSeparable(
name_scope="conv4_2",
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=2,
scale=scale),
name="conv4_2")
self.dwsl.append(dws42)
for i in range(5):
tmp = self.add_sublayer(
sublayer=DepthwiseSeparable(
name_scope="conv5_" + str(i + 1),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
scale=scale),
name="conv5_" + str(i + 1))
self.dwsl.append(tmp)
dws56 = self.add_sublayer(
sublayer=DepthwiseSeparable(
name_scope="conv5_6",
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=2,
scale=scale),
name="conv5_6")
self.dwsl.append(dws56)
dws6 = self.add_sublayer(
sublayer=DepthwiseSeparable(
name_scope="conv6",
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=1,
scale=scale),
name="conv6")
self.dwsl.append(dws6)
self.pool2d_avg = Pool2D(
name_scope="pool", pool_type='avg', global_pooling=True)
self.out = FC(name_scope="fc",
size=class_dim,
param_attr=ParamAttr(
initializer=MSRA(),
name=self.full_name() + "fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
def forward(self, inputs):
y = self.conv1(inputs)
for dws in self.dwsl:
y = dws(y)
y = self.pool2d_avg(y)
y = self.out(y)
return y
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import sys
import math
import argparse
import ast
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
name=None,
use_cudnn=True):
super(ConvBNLayer, self).__init__(name)
tmp_param = ParamAttr(name=name + "_weights")
self._conv = Conv2D(
self.full_name(),
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=tmp_param,
bias_attr=False)
self._batch_norm = BatchNorm(
self.full_name(),
num_filters,
param_attr=ParamAttr(name=name + "_bn" + "_scale"),
bias_attr=ParamAttr(name=name + "_bn" + "_offset"),
moving_mean_name=name + "_bn" + '_mean',
moving_variance_name=name + "_bn" + '_variance')
def forward(self, inputs, if_act=True):
y = self._conv(inputs)
y = self._batch_norm(y)
if if_act:
y = fluid.layers.relu6(y)
return y
class InvertedResidualUnit(fluid.dygraph.Layer):
def __init__(self,
num_in_filter,
num_filters,
stride,
filter_size,
padding,
expansion_factor,
name=None):
super(InvertedResidualUnit, self).__init__(name)
num_expfilter = int(round(num_in_filter * expansion_factor))
self._expand_conv = ConvBNLayer(
name=name + "_expand",
num_filters=num_expfilter,
filter_size=1,
stride=1,
padding=0,
num_groups=1)
self._bottleneck_conv = ConvBNLayer(
name=name + "_dwise",
num_filters=num_expfilter,
filter_size=filter_size,
stride=stride,
padding=padding,
num_groups=num_expfilter,
use_cudnn=False)
self._linear_conv = ConvBNLayer(
name=name + "_linear",
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
num_groups=1)
def forward(self, inputs, ifshortcut):
y = self._expand_conv(inputs, if_act=True)
y = self._bottleneck_conv(y, if_act=True)
y = self._linear_conv(y, if_act=False)
if ifshortcut:
y = fluid.layers.elementwise_add(inputs, y)
return y
class InvresiBlocks(fluid.dygraph.Layer):
def __init__(self, in_c, t, c, n, s, name=None):
super(InvresiBlocks, self).__init__(name)
self._first_block = InvertedResidualUnit(
name=name + "_1",
num_in_filter=in_c,
num_filters=c,
stride=s,
filter_size=3,
padding=1,
expansion_factor=t)
self._inv_blocks = []
for i in range(1, n):
tmp = self.add_sublayer(
sublayer=InvertedResidualUnit(
name=name + "_" + str(i + 1),
num_in_filter=c,
num_filters=c,
stride=1,
filter_size=3,
padding=1,
expansion_factor=t),
name=name + "_" + str(i + 1))
self._inv_blocks.append(tmp)
def forward(self, inputs):
y = self._first_block(inputs, ifshortcut=False)
for inv_block in self._inv_blocks:
y = inv_block(y, ifshortcut=True)
return y
class MobileNetV2(fluid.dygraph.Layer):
def __init__(self, name, class_dim=1000, scale=1.0):
super(MobileNetV2, self).__init__(name)
self.scale = scale
self.class_dim = class_dim
bottleneck_params_list = [
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 3, 2),
(6, 64, 4, 2),
(6, 96, 3, 1),
(6, 160, 3, 2),
(6, 320, 1, 1),
]
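        # Each tuple is (t, c, n, s): expansion factor, output channels,
        # number of repeated blocks, and stride of the first block in the
        # sequence, following Table 2 of the MobileNetV2 paper.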
#1. conv1
self._conv1 = ConvBNLayer(
name="conv1_1",
num_filters=int(32 * scale),
filter_size=3,
stride=2,
padding=1)
#2. bottleneck sequences
self._invl = []
i = 1
in_c = int(32 * scale)
for layer_setting in bottleneck_params_list:
t, c, n, s = layer_setting
i += 1
tmp = self.add_sublayer(
sublayer=InvresiBlocks(
name='conv' + str(i),
in_c=in_c,
t=t,
c=int(c * scale),
n=n,
s=s),
name='conv' + str(i))
self._invl.append(tmp)
in_c = int(c * scale)
#3. last_conv
self._conv9 = ConvBNLayer(
name="conv9",
num_filters=int(1280 * scale) if scale > 1.0 else 1280,
filter_size=1,
stride=1,
padding=0)
#4. pool
self._pool2d_avg = Pool2D(
name_scope="pool", pool_type='avg', global_pooling=True)
#5. fc
tmp_param = ParamAttr(name="fc10_weights")
self._fc = FC(name_scope="fc",
size=class_dim,
param_attr=tmp_param,
bias_attr=ParamAttr(name="fc10_offset"))
def forward(self, inputs):
y = self._conv1(inputs, if_act=True)
for inv in self._invl:
y = inv(y)
y = self._conv9(y, if_act=True)
y = self._pool2d_avg(y)
y = self._fc(y)
return y
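A matching sketch for MobileNetV2, mirroring the constructor call used in train.py (illustrative only; the place defaults to GPU 0 on a CUDA build, else CPU):

import numpy as np
import paddle.fluid as fluid
from mobilenet_v2 import MobileNetV2

with fluid.dygraph.guard():
    net = MobileNetV2(name="mobilenet_v2", class_dim=1000, scale=1.0)
    net.eval()
    x = fluid.dygraph.to_variable(
        np.random.rand(1, 3, 224, 224).astype('float32'))
    print(net(x).shape)  # expect [1, 1000]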
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import math
import random
import functools
import numpy as np
import cv2
import paddle
from paddle import fluid
from utils.autoaugment import ImageNetPolicy
from PIL import Image
policy = None
random.seed(0)
np.random.seed(0)
def rotate_image(img):
"""rotate image
Args:
img: image data
Returns:
rotated image data
"""
(h, w) = img.shape[:2]
center = (w / 2, h / 2)
angle = np.random.randint(-10, 11)
M = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated = cv2.warpAffine(img, M, (w, h))
return rotated
def random_crop(img, size, settings, scale=None, ratio=None,
interpolation=None):
"""random crop image
Args:
img: image data
size: crop size
settings: arguments
scale: scale parameter
ratio: ratio parameter
Returns:
random cropped image data
"""
lower_scale = settings.lower_scale
lower_ratio = settings.lower_ratio
upper_ratio = settings.upper_ratio
scale = [lower_scale, 1.0] if scale is None else scale
ratio = [lower_ratio, upper_ratio] if ratio is None else ratio
aspect_ratio = math.sqrt(np.random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.shape[0]) / img.shape[1]) / (h**2),
(float(img.shape[1]) / img.shape[0]) / (w**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.shape[0] * img.shape[1] * np.random.uniform(scale_min,
scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = np.random.randint(0, img.shape[0] - h + 1)
j = np.random.randint(0, img.shape[1] - w + 1)
img = img[i:i + h, j:j + w, :]
if interpolation:
resized = cv2.resize(img, (size, size), interpolation=interpolation)
else:
resized = cv2.resize(img, (size, size))
return resized
# NOTE (2019/08/08): distort_color is not implemented; it currently returns the image unchanged.
def distort_color(img):
"""distort image color
Args:
img: image data
Returns:
distorted color image data
"""
return img
def resize_short(img, target_size, interpolation=None):
"""resize image
Args:
img: image data
target_size: resize short target size
interpolation: interpolation mode
Returns:
resized image data
"""
percent = float(target_size) / min(img.shape[0], img.shape[1])
resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent))
if interpolation:
resized = cv2.resize(
img, (resized_width, resized_height), interpolation=interpolation)
else:
resized = cv2.resize(img, (resized_width, resized_height))
return resized
def crop_image(img, target_size, center):
"""crop image
Args:
img: images data
target_size: crop target size
center: crop mode
Returns:
img: cropped image data
"""
height, width = img.shape[:2]
size = target_size
    if center:
w_start = (width - size) // 2
h_start = (height - size) // 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img[h_start:h_end, w_start:w_end, :]
return img
def create_mixup_reader(settings, rd):
"""
"""
class context:
tmp_mix = []
tmp_l1 = []
tmp_l2 = []
tmp_lam = []
alpha = settings.mixup_alpha
def fetch_data():
for item in rd():
yield item
def mixup_data():
for data_list in fetch_data():
if alpha > 0.:
lam = np.random.beta(alpha, alpha)
else:
lam = 1.
l1 = np.array(data_list)
l2 = np.random.permutation(l1)
mixed_l = [
l1[i][0] * lam + (1 - lam) * l2[i][0] for i in range(len(l1))
]
yield (mixed_l, l1, l2, lam)
def mixup_reader():
for context.tmp_mix, context.tmp_l1, context.tmp_l2, context.tmp_lam in mixup_data(
):
for i in range(len(context.tmp_mix)):
mixed_l = context.tmp_mix[i]
l1 = context.tmp_l1[i]
l2 = context.tmp_l2[i]
lam = context.tmp_lam
yield (mixed_l, int(l1[1]), int(l2[1]), float(lam))
return mixup_reader
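# NOTE (editorial sketch): mixup blends each sample with a random partner,
#     x_mix = lam * x1 + (1 - lam) * x2,   lam ~ Beta(alpha, alpha),
# and at training time the mixup loss blends the two labels the same way,
#     loss = lam * CE(pred, y1) + (1 - lam) * CE(pred, y2),
# which is why mixup_reader yields (mixed_image, label1, label2, lam).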
def process_image(sample, settings, mode, color_jitter, rotate):
""" process_image """
mean = settings.image_mean
std = settings.image_std
crop_size = settings.crop_size
img_path = sample[0]
img = cv2.imread(img_path)
if mode == 'train':
if rotate:
img = rotate_image(img)
if crop_size > 0:
img = random_crop(
img, crop_size, settings, interpolation=settings.interpolation)
if color_jitter:
img = distort_color(img)
if np.random.randint(0, 2) == 1:
img = img[:, ::-1, :]
else:
if crop_size > 0:
target_size = settings.resize_short_size
img = resize_short(
img, target_size, interpolation=settings.interpolation)
img = crop_image(img, target_size=crop_size, center=True)
img = img[:, :, ::-1]
if 'use_aa' in settings and settings.use_aa and mode == 'train':
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
img = policy(img)
img = np.asarray(img)
img = img.astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
if mode == 'train' or mode == 'val':
return (img, sample[1])
elif mode == 'test':
return (img, )
def process_batch_data(input_data, settings, mode, color_jitter, rotate):
batch_data = []
for sample in input_data:
if os.path.isfile(sample[0]):
batch_data.append(
process_image(sample, settings, mode, color_jitter, rotate))
else:
print("File not exist : %s" % sample[0])
return batch_data
class ImageNetReader:
def __init__(self, seed=None):
self.shuffle_seed = seed
def set_shuffle_seed(self, seed):
assert isinstance(seed, int), "shuffle seed must be int"
self.shuffle_seed = seed
def _reader_creator(self,
settings,
file_list,
mode,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=None):
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if mode == 'test':
batch_size = 1
else:
            batch_size = int(settings.batch_size /
                             paddle.fluid.core.get_cuda_device_count())
def reader():
def read_file_list():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if mode != "test" and len(full_lines) < settings.batch_size:
print(
"Warning: The number of the whole data ({}) is smaller than the batch_size ({}), and drop_last is turnning on, so nothing will feed in program, Terminated now. Please reset batch_size to a smaller number or feed more data!"
.format(len(full_lines), settings.batch_size))
os._exit(1)
if num_trainers > 1 and mode == "train":
assert self.shuffle_seed is not None, "multiprocess train, shuffle seed must be set!"
np.random.RandomState(self.shuffle_seed).shuffle(
full_lines)
elif shuffle:
                        assert self.shuffle_seed is not None, "shuffle seed must be set when shuffle is enabled!"
np.random.RandomState(self.shuffle_seed).shuffle(
full_lines)
batch_data = []
for line in full_lines:
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
batch_data.append([img_path, int(label)])
if len(batch_data) == batch_size:
if mode == 'train' or mode == 'val' or mode == 'test':
yield batch_data
batch_data = []
return read_file_list
data_reader = reader()
if mode == 'train' and num_trainers > 1:
assert self.shuffle_seed is not None, \
"If num_trainers > 1, the shuffle_seed must be set, because " \
"the order of batch data generated by reader " \
"must be the same in the respective processes."
data_reader = paddle.fluid.contrib.reader.distributed_batch_reader(
data_reader)
mapper = functools.partial(
process_batch_data,
settings=settings,
mode=mode,
color_jitter=color_jitter,
rotate=rotate)
ret = fluid.io.xmap_readers(
mapper,
data_reader,
settings.reader_thread,
settings.reader_buf_size,
order=False)
return ret
def train(self, settings):
"""Create a reader for trainning
Args:
settings: arguments
Returns:
train reader
"""
file_list = os.path.join(settings.data_dir, 'train_list.txt')
assert os.path.isfile(
file_list), "{} doesn't exist, please check data list path".format(
file_list)
if 'use_aa' in settings and settings.use_aa:
global policy
policy = ImageNetPolicy()
reader = self._reader_creator(
settings,
file_list,
'train',
shuffle=True,
color_jitter=False,
rotate=False,
data_dir=settings.data_dir)
        if settings.use_mixup:
reader = create_mixup_reader(settings, reader)
reader = fluid.io.batch(
reader,
batch_size=int(settings.batch_size /
paddle.fluid.core.get_cuda_device_count()),
drop_last=True)
return reader
def val(self, settings):
"""Create a reader for eval
Args:
settings: arguments
Returns:
eval reader
"""
file_list = os.path.join(settings.data_dir, 'val_list.txt')
assert os.path.isfile(
file_list), "{} doesn't exist, please check data list path".format(
file_list)
return self._reader_creator(
settings,
file_list,
'val',
shuffle=False,
data_dir=settings.data_dir)
def test(self, settings):
"""Create a reader for testing
Args:
settings: arguments
Returns:
test reader
"""
file_list = os.path.join(settings.data_dir, 'val_list.txt')
assert os.path.isfile(
file_list), "{} doesn't exist, please check data list path".format(
file_list)
return self._reader_creator(
settings,
file_list,
'test',
shuffle=False,
data_dir=settings.data_dir)
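A minimal sketch of driving the reader outside train.py (illustrative; it assumes a CUDA build and that utils.parse_args() supplies every field reader.py reads, such as data_dir, batch_size, reader_thread, reader_buf_size and use_mixup):

import reader
from utils import parse_args

args = parse_args()
imagenet_reader = reader.ImageNetReader(seed=0)  # train.py also passes seed 0
train_reader = imagenet_reader.train(settings=args)
for batch in train_reader():
    # without mixup, each batch is a list of (CHW float32 image, int label)
    print(len(batch))
    break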
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --log_dir ./mylog.time train.py --use_data_parallel 1 --batch_size=256 --reader_thread=8 --total_images=1281167 --class_dim=1000 --image_shape=3,224,224 --model_save_dir=output/ --lr_strategy=piecewise_decay --lr=0.1 --data_dir=../../PaddleCV/image_classification/data/ILSVRC2012 --l2_decay=3e-5 --model=MobileNetV1
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --log_dir ./mylog.time train.py --use_data_parallel 1 --batch_size=256 --reader_thread=8 --total_images=1281167 --class_dim=1000 --image_shape=3,224,224 --model_save_dir=output/ --lr_strategy=piecewise_decay --lr=0.1 --data_dir=../../PaddleCV/image_classification/data/ILSVRC2012 --l2_decay=3e-5 --model=MobileNetV2
export CUDA_VISIBLE_DEVICES=0
python train.py --batch_size=256 --total_images=1281167 --class_dim=1000 --image_shape=3,224,224 --model_save_dir=output/ --lr_strategy=piecewise_decay --lr=0.1 --data_dir=../../PaddleCV/image_classification/data/ILSVRC2012 --l2_decay=3e-5 --model=MobileNetV1
export CUDA_VISIBLE_DEVICES=0
python train.py --batch_size=128 --total_images=1281167 --class_dim=1000 --image_shape=3,224,224 --model_save_dir=output/ --lr_strategy=piecewise_decay --lr=0.1 --data_dir=../../PaddleCV/image_classification/data/ILSVRC2012 --model=MobileNetV2
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mobilenet_v1 import *
from mobilenet_v2 import *
import os
import time
import sys
import math
import argparse
import ast
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
import reader
from utils import *
IMAGENET1000 = 1281167
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
args = parse_args()
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
print_arguments(args)
def eval(net, test_data_loader, eop):
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
t_last = 0
for img, label in test_data_loader():
t1 = time.time()
label = to_variable(label.numpy().astype('int64').reshape(
int(args.batch_size / paddle.fluid.core.get_cuda_device_count()),
1))
out = net(img)
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
loss = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_loss = fluid.layers.mean(x=loss)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
t2 = time.time()
print( "test | epoch id: %d, avg_loss %0.5f acc_top1 %0.5f acc_top5 %0.5f %2.4f sec read_t:%2.4f" % \
(eop, avg_loss.numpy(), acc_top1.numpy(), acc_top5.numpy(), t2 - t1 , t1 - t_last))
sys.stdout.flush()
total_loss += avg_loss.numpy()
total_acc1 += acc_top1.numpy()
total_acc5 += acc_top5.numpy()
total_sample += 1
t_last = time.time()
print("final eval loss %0.3f acc1 %0.3f acc5 %0.3f" % \
(total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))
sys.stdout.flush()
def train_mobilenet():
epoch = args.num_epochs
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
if args.ce:
print("ce mode")
seed = 33
np.random.seed(seed)
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
net = None
if args.model == "MobileNetV1":
net = MobileNetV1("mobilenet_v1", class_dim=args.class_dim)
para_name = 'mobilenet_v1_params'
elif args.model == "MobileNetV2":
net = MobileNetV2(
name="mobilenet_v2", class_dim=args.class_dim, scale=1.0)
para_name = 'mobilenet_v2_params'
else:
            print(
                "wrong model name, please set --model to MobileNetV1 or MobileNetV2"
            )
exit()
optimizer = create_optimizer(args)
if args.use_data_parallel:
net = fluid.dygraph.parallel.DataParallel(net, strategy)
train_data_loader, train_data = utility.create_data_loader(
is_train=True, args=args)
test_data_loader, test_data = utility.create_data_loader(
is_train=False, args=args)
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
imagenet_reader = reader.ImageNetReader(0)
train_reader = imagenet_reader.train(settings=args)
test_reader = imagenet_reader.val(settings=args)
train_data_loader.set_sample_list_generator(train_reader, place)
test_data_loader.set_sample_list_generator(test_reader, place)
for eop in range(epoch):
if num_trainers > 1:
imagenet_reader.set_shuffle_seed(eop + (
args.random_seed if args.random_seed else 0))
net.train()
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
batch_id = 0
t_last = 0
for img, label in train_data_loader():
t1 = time.time()
label = to_variable(label.numpy().astype('int64').reshape(
int(args.batch_size /
paddle.fluid.core.get_cuda_device_count()), 1))
t_start = time.time()
out = net(img)
t_end = time.time()
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
loss = fluid.layers.cross_entropy(
input=softmax_out, label=label)
avg_loss = fluid.layers.mean(x=loss)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
t_start_back = time.time()
if args.use_data_parallel:
avg_loss = net.scale_loss(avg_loss)
avg_loss.backward()
net.apply_collective_grads()
else:
avg_loss.backward()
t_end_back = time.time()
optimizer.minimize(avg_loss)
net.clear_gradients()
t2 = time.time()
train_batch_elapse = t2 - t1
if batch_id % args.print_step == 0:
print( "epoch id: %d, batch step: %d, avg_loss %0.5f acc_top1 %0.5f acc_top5 %0.5f %2.4f sec net_t:%2.4f back_t:%2.4f read_t:%2.4f" % \
(eop, batch_id, avg_loss.numpy(), acc_top1.numpy(), acc_top5.numpy(), train_batch_elapse,
t_end - t_start, t_end_back - t_start_back, t1 - t_last))
sys.stdout.flush()
total_loss += avg_loss.numpy()
total_acc1 += acc_top1.numpy()
total_acc5 += acc_top5.numpy()
total_sample += 1
batch_id += 1
t_last = time.time()
if args.ce:
print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample))
print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / total_sample))
print("kpis\ttrain_loss\t%0.3f" % (total_loss / total_sample))
print("epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f %2.4f sec" % \
(eop, batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample, train_batch_elapse))
net.eval()
eval(net, test_data_loader, eop)
save_parameters = (not args.use_data_parallel) or (
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
fluid.save_dygraph(net.state_dict(), para_name)
if __name__ == '__main__':
train_mobilenet()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .optimizer import cosine_decay, lr_warmup, cosine_decay_with_warmup, exponential_decay_with_warmup, Optimizer, create_optimizer
from .utility import add_arguments, print_arguments, parse_args, check_gpu, check_args, check_version, init_model, save_model, create_data_loader, print_info, best_strategy_compiled, ExponentialMovingAverage
"""
This code is based on https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
"""
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
class ImageNetPolicy(object):
""" Randomly choose one of the best 24 Sub-policies on ImageNet.
Example:
>>> policy = ImageNetPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
]
def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment ImageNet Policy"
class CIFAR10Policy(object):
""" Randomly choose one of the best 25 Sub-policies on CIFAR10.
Example:
>>> policy = CIFAR10Policy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> CIFAR10Policy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
]
def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment CIFAR10 Policy"
class SVHNPolicy(object):
""" Randomly choose one of the best 25 Sub-policies on SVHN.
Example:
>>> policy = SVHNPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> SVHNPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
]
def __call__(self, img, policy_idx=None):
if policy_idx is None or not isinstance(policy_idx, int):
policy_idx = random.randint(0, len(self.policies) - 1)
else:
policy_idx = policy_idx % len(self.policies)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment SVHN Policy"
class SubPolicy(object):
def __init__(self,
p1,
operation1,
magnitude_idx1,
p2,
operation2,
magnitude_idx2,
fillcolor=(128, 128, 128)):
ranges = {
"shearX": np.linspace(0, 0.3, 10),
"shearY": np.linspace(0, 0.3, 10),
"translateX": np.linspace(0, 150 / 331, 10),
"translateY": np.linspace(0, 150 / 331, 10),
"rotate": np.linspace(0, 30, 10),
"color": np.linspace(0.0, 0.9, 10),
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
"solarize": np.linspace(256, 0, 10),
"contrast": np.linspace(0.0, 0.9, 10),
"sharpness": np.linspace(0.0, 0.9, 10),
"brightness": np.linspace(0.0, 0.9, 10),
"autocontrast": [0] * 10,
"equalize": [0] * 10,
"invert": [0] * 10
}
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot,
Image.new("RGBA", rot.size, (128, ) * 4),
rot).convert(img.mode)
func = {
"shearX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
# "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
self.p1 = p1
self.operation1 = func[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
self.p2 = p2
self.operation2 = func[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
if random.random() < self.p1:
img = self.operation1(img, self.magnitude1)
if random.random() < self.p2:
img = self.operation2(img, self.magnitude2)
return img
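The policy objects are applied to PIL images, exactly as reader.py does when use_aa is enabled. A standalone sketch with a synthetic stand-in image (illustrative only):

import numpy as np
from PIL import Image
from utils.autoaugment import ImageNetPolicy

policy = ImageNetPolicy()
img = Image.fromarray(
    np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
augmented = policy(img)  # applies one randomly chosen sub-policy
print(augmented.size)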
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid
def nccl2_prepare(args, startup_prog, main_prog):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
envs = args.dist_env
t.transpile(
envs["trainer_id"],
trainers=','.join(envs["trainer_endpoints"]),
current_endpoint=envs["current_endpoint"],
startup_program=startup_prog,
program=main_prog)
def pserver_prepare(args, train_prog, startup_prog):
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = args.split_var
t = fluid.DistributeTranspiler(config=config)
envs = args.dist_env
training_role = envs["training_role"]
t.transpile(
envs["trainer_id"],
program=train_prog,
pservers=envs["pserver_endpoints"],
trainers=envs["num_trainers"],
sync_mode=not args.async_mode,
startup_program=startup_prog)
if training_role == "PSERVER":
pserver_program = t.get_pserver_program(envs["current_endpoint"])
pserver_startup_program = t.get_startup_program(
envs["current_endpoint"],
pserver_program,
startup_program=startup_prog)
return pserver_program, pserver_startup_program
elif training_role == "TRAINER":
train_program = t.get_trainer_program()
return train_program, startup_prog
else:
raise ValueError(
'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def nccl2_prepare_paddle(trainer_id, startup_prog, main_prog):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'),
current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'),
startup_program=startup_prog,
program=main_prog)
def prepare_for_multi_process(exe, build_strategy, train_prog):
# prepare for multi-process
trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers < 2: return
print("PADDLE_TRAINERS_NUM", num_trainers)
print("PADDLE_TRAINER_ID", trainer_id)
build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id
# NOTE(zcd): use multi processes to train the model,
# and each process use one GPU card.
startup_prog = fluid.Program()
nccl2_prepare_paddle(trainer_id, startup_prog, train_prog)
    # the startup_prog is run twice, but it doesn't matter.
exe.run(startup_prog)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
import paddle.fluid.layers.ops as ops
from paddle.fluid.initializer import init_on_cpu
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
def cosine_decay(learning_rate, step_each_epoch, epochs=120):
"""Applies cosine decay to the learning rate.
lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
"""
global_step = _decay_step_counter()
with init_on_cpu():
epoch = ops.floor(global_step / step_each_epoch)
decayed_lr = learning_rate * \
(ops.cos(epoch * (math.pi / epochs)) + 1)/2
return decayed_lr
def cosine_decay_with_warmup(learning_rate, step_each_epoch, epochs=120):
"""Applies cosine decay to the learning rate.
lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
decrease lr for every mini-batch and start with warmup.
"""
global_step = _decay_step_counter()
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
warmup_epoch = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(5), force_cpu=True)
with init_on_cpu():
epoch = ops.floor(global_step / step_each_epoch)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(epoch < warmup_epoch):
decayed_lr = learning_rate * (global_step /
(step_each_epoch * warmup_epoch))
fluid.layers.tensor.assign(input=decayed_lr, output=lr)
with switch.default():
decayed_lr = learning_rate * \
(ops.cos((global_step - warmup_epoch * step_each_epoch) * (math.pi / (epochs * step_each_epoch))) + 1)/2
fluid.layers.tensor.assign(input=decayed_lr, output=lr)
return lr
def exponential_decay_with_warmup(learning_rate,
step_each_epoch,
decay_epochs,
decay_rate=0.97,
warm_up_epoch=5.0):
"""Applies exponential decay to the learning rate.
"""
global_step = _decay_step_counter()
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
warmup_epoch = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(warm_up_epoch), force_cpu=True)
with init_on_cpu():
epoch = ops.floor(global_step / step_each_epoch)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(epoch < warmup_epoch):
decayed_lr = learning_rate * (global_step /
(step_each_epoch * warmup_epoch))
fluid.layers.assign(input=decayed_lr, output=lr)
with switch.default():
div_res = (
global_step - warmup_epoch * step_each_epoch) / decay_epochs
div_res = ops.floor(div_res)
decayed_lr = learning_rate * (decay_rate**div_res)
fluid.layers.assign(input=decayed_lr, output=lr)
return lr
def lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
""" Applies linear learning rate warmup for distributed training
Argument learning_rate can be float or a Variable
lr = lr + (warmup_rate * step / warmup_steps)
"""
assert (isinstance(end_lr, float))
assert (isinstance(start_lr, float))
linear_step = end_lr - start_lr
with fluid.default_main_program()._lr_schedule_guard():
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate_warmup")
global_step = fluid.layers.learning_rate_scheduler._decay_step_counter()
with fluid.layers.control_flow.Switch() as switch:
with switch.case(global_step < warmup_steps):
decayed_lr = start_lr + linear_step * (global_step /
warmup_steps)
fluid.layers.tensor.assign(decayed_lr, lr)
with switch.default():
fluid.layers.tensor.assign(learning_rate, lr)
return lr
class Optimizer(object):
"""A class used to represent several optimizer methods
Attributes:
batch_size: batch size on all devices.
lr: learning rate.
lr_strategy: learning rate decay strategy.
l2_decay: l2_decay parameter.
momentum_rate: momentum rate when using Momentum optimizer.
step_epochs: piecewise decay steps.
num_epochs: number of total epochs.
total_images: total images.
        step: total steps in an epoch.
"""
def __init__(self, args):
self.batch_size = args.batch_size
self.lr = args.lr
self.lr_strategy = args.lr_strategy
self.l2_decay = args.l2_decay
self.momentum_rate = args.momentum_rate
self.step_epochs = args.step_epochs
self.num_epochs = args.num_epochs
self.warm_up_epochs = args.warm_up_epochs
self.decay_epochs = args.decay_epochs
self.decay_rate = args.decay_rate
self.total_images = args.total_images
self.step = int(math.ceil(float(self.total_images) / self.batch_size))
def piecewise_decay(self):
"""piecewise decay with Momentum optimizer
Returns:
a piecewise_decay optimizer
"""
bd = [self.step * e for e in self.step_epochs]
lr = [self.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=self.momentum_rate,
regularization=fluid.regularizer.L2Decay(self.l2_decay))
return optimizer
def cosine_decay(self):
"""cosine decay with Momentum optimizer
Returns:
a cosine_decay optimizer
"""
learning_rate = fluid.layers.cosine_decay(
learning_rate=self.lr,
step_each_epoch=self.step,
epochs=self.num_epochs)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=self.momentum_rate,
regularization=fluid.regularizer.L2Decay(self.l2_decay))
return optimizer
def cosine_decay_warmup(self):
"""cosine decay with warmup
Returns:
a cosine_decay_with_warmup optimizer
"""
learning_rate = cosine_decay_with_warmup(
learning_rate=self.lr,
step_each_epoch=self.step,
epochs=self.num_epochs)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=self.momentum_rate,
regularization=fluid.regularizer.L2Decay(self.l2_decay))
return optimizer
def exponential_decay_warmup(self):
"""exponential decay with warmup
Returns:
            an exponential_decay_with_warmup optimizer
"""
learning_rate = exponential_decay_with_warmup(
learning_rate=self.lr,
step_each_epoch=self.step,
decay_epochs=self.step * self.decay_epochs,
decay_rate=self.decay_rate,
warm_up_epoch=self.warm_up_epochs)
optimizer = fluid.optimizer.RMSProp(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(self.l2_decay),
momentum=self.momentum_rate,
rho=0.9,
epsilon=0.001)
return optimizer
def linear_decay(self):
"""linear decay with Momentum optimizer
Returns:
a linear_decay optimizer
"""
end_lr = 0
learning_rate = fluid.layers.polynomial_decay(
self.lr, self.step, end_lr, power=1)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=self.momentum_rate,
regularization=fluid.regularizer.L2Decay(self.l2_decay))
return optimizer
def adam_decay(self):
"""Adam optimizer
Returns:
an adam_decay optimizer
"""
return fluid.optimizer.Adam(learning_rate=self.lr)
def cosine_decay_RMSProp(self):
"""cosine decay with RMSProp optimizer
Returns:
            a cosine_decay_RMSProp optimizer
"""
learning_rate = fluid.layers.cosine_decay(
learning_rate=self.lr,
step_each_epoch=self.step,
epochs=self.num_epochs)
optimizer = fluid.optimizer.RMSProp(
learning_rate=learning_rate,
momentum=self.momentum_rate,
regularization=fluid.regularizer.L2Decay(self.l2_decay),
# Apply epsilon=1 on ImageNet dataset.
epsilon=1)
return optimizer
def default_decay(self):
"""default decay
Returns:
default decay optimizer
"""
optimizer = fluid.optimizer.Momentum(
learning_rate=self.lr,
momentum=self.momentum_rate,
regularization=fluid.regularizer.L2Decay(self.l2_decay))
return optimizer
def create_optimizer(args):
Opt = Optimizer(args)
optimizer = getattr(Opt, args.lr_strategy)()
return optimizer
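This is how train.py obtains its optimizer (a minimal sketch; parse_args and create_optimizer are both re-exported by utils/__init__.py):

from utils import parse_args, create_optimizer

args = parse_args()  # e.g. --lr_strategy=piecewise_decay, as in the run scripts
optimizer = create_optimizer(args)  # getattr dispatch to Optimizer.<lr_strategy>()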
This diff has been collapsed.