From ae89990bc813af2abcf6d03a5da379c671a65038 Mon Sep 17 00:00:00 2001
From: chajchaj <57249073+chajchaj@users.noreply.github.com>
Date: Wed, 15 Jan 2020 20:47:33 +0800
Subject: [PATCH] Add MobileNet v1 and v2 dygraph code (#4188)

---
 dygraph/mobilenet/README.md            |  44 ++
 dygraph/mobilenet/mobilenet_v1.py      | 239 ++++
 dygraph/mobilenet/mobilenet_v2.py      | 226 ++++
 dygraph/mobilenet/reader.py            | 414 ++++++
 dygraph/mobilenet/run_mul_v1.sh        |   2 +
 dygraph/mobilenet/run_mul_v2.sh        |   2 +
 dygraph/mobilenet/run_sing_v1.sh       |   2 +
 dygraph/mobilenet/run_sing_v2.sh       |   2 +
 dygraph/mobilenet/train.py             | 188 ++++
 dygraph/mobilenet/utils/__init__.py    |  15 +
 dygraph/mobilenet/utils/autoaugment.py | 245 +++++
 dygraph/mobilenet/utils/dist_utils.py  |  93 ++++
 dygraph/mobilenet/utils/optimizer.py   | 299 +++++
 dygraph/mobilenet/utils/utility.py     | 576 +++++++++++++++++++++++++
 14 files changed, 2347 insertions(+)
 create mode 100644 dygraph/mobilenet/README.md
 create mode 100644 dygraph/mobilenet/mobilenet_v1.py
 create mode 100644 dygraph/mobilenet/mobilenet_v2.py
 create mode 100644 dygraph/mobilenet/reader.py
 create mode 100644 dygraph/mobilenet/run_mul_v1.sh
 create mode 100644 dygraph/mobilenet/run_mul_v2.sh
 create mode 100644 dygraph/mobilenet/run_sing_v1.sh
 create mode 100644 dygraph/mobilenet/run_sing_v2.sh
 create mode 100644 dygraph/mobilenet/train.py
 create mode 100644 dygraph/mobilenet/utils/__init__.py
 create mode 100644 dygraph/mobilenet/utils/autoaugment.py
 create mode 100755 dygraph/mobilenet/utils/dist_utils.py
 create mode 100644 dygraph/mobilenet/utils/optimizer.py
 create mode 100644 dygraph/mobilenet/utils/utility.py

diff --git a/dygraph/mobilenet/README.md b/dygraph/mobilenet/README.md
new file mode 100644
index 00000000..beee2f1b
--- /dev/null
+++ b/dygraph/mobilenet/README.md
@@ -0,0 +1,44 @@
+**Introduction**
+
+Image classification is a core task in computer vision: its goal is to assign an image to one of a set of predefined labels. CNN models have achieved breakthrough results in image classification while steadily growing in complexity. MobileNet is a compact and efficient CNN model; this document describes how to use the PaddlePaddle dygraph implementation of MobileNet for image classification.
+
+**Code structure**
+
+    ├── run_mul_v1.sh     # multi-GPU launch script for v1
+    ├── run_mul_v2.sh     # multi-GPU launch script for v2
+    ├── run_sing_v1.sh    # single-GPU launch script for v1
+    ├── run_sing_v2.sh    # single-GPU launch script for v2
+    ├── train.py          # training entry point
+    ├── mobilenet_v1.py   # network definition, v1
+    ├── mobilenet_v2.py   # network definition, v2
+    ├── reader.py         # data reader
+    ├── utils             # utility modules
+
+**Data preparation**
+
+Please refer to: https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification
+
+**Training**
+
+To train on 4 GPUs, launch with:
+
+    bash run_mul_v1.sh
+    bash run_mul_v2.sh
+To train on a single GPU, launch with:
+
+    bash run_sing_v1.sh
+    bash run_sing_v2.sh
+
+**Accuracy**
+
+    Model          Top-1    Top-5
+
+    MobileNetV1    0.707    0.895
+
+    MobileNetV2    0.626    0.845
+
+**References**
+
+MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications, Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
+
+MobileNetV2: Inverted Residuals and Linear Bottlenecks, Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
diff --git a/dygraph/mobilenet/mobilenet_v1.py b/dygraph/mobilenet/mobilenet_v1.py
new file mode 100644
index 00000000..f6d35aa0
--- /dev/null
+++ b/dygraph/mobilenet/mobilenet_v1.py
@@ -0,0 +1,239 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import sys +import numpy as np +import argparse +import ast +import paddle +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC +from paddle.fluid.dygraph.base import to_variable +from paddle.fluid import framework +import math +import sys + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + name_scope, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + act='relu', + use_cudnn=True, + name=None): + super(ConvBNLayer, self).__init__(name_scope) + + self._conv = Conv2D( + self.full_name(), + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr( + initializer=MSRA(), name=self.full_name() + "_weights"), + bias_attr=False) + + self._batch_norm = BatchNorm( + self.full_name(), + num_filters, + act=act, + param_attr=ParamAttr(name="_bn" + "_scale"), + bias_attr=ParamAttr(name="_bn" + "_offset"), + moving_mean_name="_bn" + '_mean', + moving_variance_name="_bn" + '_variance') + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class DepthwiseSeparable(fluid.dygraph.Layer): + def __init__(self, + name_scope, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + name=None): + super(DepthwiseSeparable, self).__init__(name_scope) + + self._depthwise_conv = ConvBNLayer( + name_scope="dw", + num_filters=int(num_filters1 * scale), + filter_size=3, + stride=stride, + padding=1, + num_groups=int(num_groups * scale), + use_cudnn=False) + + self._pointwise_conv = ConvBNLayer( + name_scope="sep", + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0) + + def forward(self, inputs): + y = self._depthwise_conv(inputs) + y = self._pointwise_conv(y) + return y + + +class MobileNetV1(fluid.dygraph.Layer): + def __init__(self, name_scope, scale=1.0, class_dim=102): + super(MobileNetV1, self).__init__(name_scope) + self.scale = scale + self.dwsl = [] + + self.conv1 = ConvBNLayer( + name_scope="conv1", + filter_size=3, + channels=3, + num_filters=int(32 * scale), + stride=2, + padding=1) + + dws21 = self.add_sublayer( + sublayer=DepthwiseSeparable( + name_scope="conv2_1", + num_filters1=32, + num_filters2=64, + num_groups=32, + stride=1, + scale=scale), + name="conv2_1") + self.dwsl.append(dws21) + + dws22 = self.add_sublayer( + sublayer=DepthwiseSeparable( + name_scope="conv2_2", + num_filters1=64, + num_filters2=128, + num_groups=64, + stride=2, + scale=scale), + name="conv2_2") + self.dwsl.append(dws22) + + dws31 = self.add_sublayer( + sublayer=DepthwiseSeparable( + name_scope="conv3_1", + num_filters1=128, + num_filters2=128, + num_groups=128, + stride=1, + scale=scale), + name="conv3_1") + self.dwsl.append(dws31) + + dws32 = self.add_sublayer( + sublayer=DepthwiseSeparable( + name_scope="conv3_2", + num_filters1=128, + num_filters2=256, + 
num_groups=128, + stride=2, + scale=scale), + name="conv3_2") + self.dwsl.append(dws32) + + dws41 = self.add_sublayer( + sublayer=DepthwiseSeparable( + name_scope="conv4_1", + num_filters1=256, + num_filters2=256, + num_groups=256, + stride=1, + scale=scale), + name="conv4_1") + self.dwsl.append(dws41) + + dws42 = self.add_sublayer( + sublayer=DepthwiseSeparable( + name_scope="conv4_2", + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=2, + scale=scale), + name="conv4_2") + self.dwsl.append(dws42) + + for i in range(5): + tmp = self.add_sublayer( + sublayer=DepthwiseSeparable( + name_scope="conv5_" + str(i + 1), + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + scale=scale), + name="conv5_" + str(i + 1)) + self.dwsl.append(tmp) + + dws56 = self.add_sublayer( + sublayer=DepthwiseSeparable( + name_scope="conv5_6", + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=2, + scale=scale), + name="conv5_6") + self.dwsl.append(dws56) + + dws6 = self.add_sublayer( + sublayer=DepthwiseSeparable( + name_scope="conv6", + num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=1, + scale=scale), + name="conv6") + self.dwsl.append(dws6) + + self.pool2d_avg = Pool2D( + name_scope="pool", pool_type='avg', global_pooling=True) + + self.out = FC(name_scope="fc", + size=class_dim, + param_attr=ParamAttr( + initializer=MSRA(), + name=self.full_name() + "fc7_weights"), + bias_attr=ParamAttr(name="fc7_offset")) + + def forward(self, inputs): + y = self.conv1(inputs) + idx = 0 + for dws in self.dwsl: + y = dws(y) + y = self.pool2d_avg(y) + y = self.out(y) + return y diff --git a/dygraph/mobilenet/mobilenet_v2.py b/dygraph/mobilenet/mobilenet_v2.py new file mode 100644 index 00000000..f49f8632 --- /dev/null +++ b/dygraph/mobilenet/mobilenet_v2.py @@ -0,0 +1,226 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
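The DepthwiseSeparable block above (a 3x3 depthwise convolution with groups equal to its channels, followed by a 1x1 pointwise convolution) is what keeps MobileNetV1 small; the V2 bottlenecks below reuse the same depthwise stage. A rough, framework-free sketch of the parameter saving versus a standard convolution (the example channel counts mirror the conv3_2 stage, the helper itself is written for this note only):

```python
def conv_params(k, c_in, c_out):
    # weight count of a standard k x k convolution (bias omitted)
    return k * k * c_in * c_out

def depthwise_separable_params(k, c_in, c_out):
    # k x k depthwise (one filter per input channel) plus a 1 x 1 pointwise conv
    return k * k * c_in + c_in * c_out

# e.g. a 128 -> 256 channel stage with 3 x 3 kernels
standard = conv_params(3, 128, 256)                  # 294912 weights
separable = depthwise_separable_params(3, 128, 256)  # 33920 weights
print(standard / separable)                          # roughly 8.7x fewer weights
```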
+ +import os +import numpy as np +import time +import sys +import sys +import numpy as np +import argparse +import ast +import paddle +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC +from paddle.fluid.dygraph.base import to_variable + +from paddle.fluid import framework + +import math +import sys + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + name=None, + use_cudnn=True): + super(ConvBNLayer, self).__init__(name) + + tmp_param = ParamAttr(name=name + "_weights") + self._conv = Conv2D( + self.full_name(), + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=tmp_param, + bias_attr=False) + + self._batch_norm = BatchNorm( + self.full_name(), + num_filters, + param_attr=ParamAttr(name=name + "_bn" + "_scale"), + bias_attr=ParamAttr(name=name + "_bn" + "_offset"), + moving_mean_name=name + "_bn" + '_mean', + moving_variance_name=name + "_bn" + '_variance') + + def forward(self, inputs, if_act=True): + y = self._conv(inputs) + y = self._batch_norm(y) + if if_act: + y = fluid.layers.relu6(y) + return y + + +class InvertedResidualUnit(fluid.dygraph.Layer): + def __init__(self, + num_in_filter, + num_filters, + stride, + filter_size, + padding, + expansion_factor, + name=None): + super(InvertedResidualUnit, self).__init__(name) + num_expfilter = int(round(num_in_filter * expansion_factor)) + self._expand_conv = ConvBNLayer( + name=name + "_expand", + num_filters=num_expfilter, + filter_size=1, + stride=1, + padding=0, + num_groups=1) + + self._bottleneck_conv = ConvBNLayer( + name=name + "_dwise", + num_filters=num_expfilter, + filter_size=filter_size, + stride=stride, + padding=padding, + num_groups=num_expfilter, + use_cudnn=False) + + self._linear_conv = ConvBNLayer( + name=name + "_linear", + num_filters=num_filters, + filter_size=1, + stride=1, + padding=0, + num_groups=1) + + def forward(self, inputs, ifshortcut): + y = self._expand_conv(inputs, if_act=True) + y = self._bottleneck_conv(y, if_act=True) + y = self._linear_conv(y, if_act=False) + if ifshortcut: + y = fluid.layers.elementwise_add(inputs, y) + return y + + +class InvresiBlocks(fluid.dygraph.Layer): + def __init__(self, in_c, t, c, n, s, name=None): + super(InvresiBlocks, self).__init__(name) + + self._first_block = InvertedResidualUnit( + name=name + "_1", + num_in_filter=in_c, + num_filters=c, + stride=s, + filter_size=3, + padding=1, + expansion_factor=t) + + self._inv_blocks = [] + for i in range(1, n): + tmp = self.add_sublayer( + sublayer=InvertedResidualUnit( + name=name + "_" + str(i + 1), + num_in_filter=c, + num_filters=c, + stride=1, + filter_size=3, + padding=1, + expansion_factor=t), + name=name + "_" + str(i + 1)) + self._inv_blocks.append(tmp) + + def forward(self, inputs): + y = self._first_block(inputs, ifshortcut=False) + for inv_block in self._inv_blocks: + y = inv_block(y, ifshortcut=True) + return y + + +class MobileNetV2(fluid.dygraph.Layer): + def __init__(self, name, class_dim=1000, scale=1.0): + super(MobileNetV2, self).__init__(name) + self.scale = scale + self.class_dim = class_dim + + bottleneck_params_list = [ + (1, 16, 1, 1), + (6, 24, 2, 2), + (6, 32, 3, 2), + (6, 64, 4, 2), + (6, 96, 3, 1), + (6, 160, 3, 
2), + (6, 320, 1, 1), + ] + + #1. conv1 + self._conv1 = ConvBNLayer( + name="conv1_1", + num_filters=int(32 * scale), + filter_size=3, + stride=2, + padding=1) + + #2. bottleneck sequences + self._invl = [] + i = 1 + in_c = int(32 * scale) + for layer_setting in bottleneck_params_list: + t, c, n, s = layer_setting + i += 1 + tmp = self.add_sublayer( + sublayer=InvresiBlocks( + name='conv' + str(i), + in_c=in_c, + t=t, + c=int(c * scale), + n=n, + s=s), + name='conv' + str(i)) + self._invl.append(tmp) + in_c = int(c * scale) + + #3. last_conv + self._conv9 = ConvBNLayer( + name="conv9", + num_filters=int(1280 * scale) if scale > 1.0 else 1280, + filter_size=1, + stride=1, + padding=0) + + #4. pool + self._pool2d_avg = Pool2D( + name_scope="pool", pool_type='avg', global_pooling=True) + + #5. fc + tmp_param = ParamAttr(name="fc10_weights") + self._fc = FC(name_scope="fc", + size=class_dim, + param_attr=tmp_param, + bias_attr=ParamAttr(name="fc10_offset")) + + def forward(self, inputs): + y = self._conv1(inputs, if_act=True) + for inv in self._invl: + y = inv(y) + y = self._conv9(y, if_act=True) + y = self._pool2d_avg(y) + y = self._fc(y) + return y diff --git a/dygraph/mobilenet/reader.py b/dygraph/mobilenet/reader.py new file mode 100644 index 00000000..b96d1366 --- /dev/null +++ b/dygraph/mobilenet/reader.py @@ -0,0 +1,414 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import sys +import os +import math +import random +import functools +import numpy as np +import cv2 + +import paddle +from paddle import fluid +from utils.autoaugment import ImageNetPolicy +from PIL import Image + +policy = None + +random.seed(0) +np.random.seed(0) + + +def rotate_image(img): + """rotate image + + Args: + img: image data + + Returns: + rotated image data + """ + (h, w) = img.shape[:2] + center = (w / 2, h / 2) + angle = np.random.randint(-10, 11) + M = cv2.getRotationMatrix2D(center, angle, 1.0) + rotated = cv2.warpAffine(img, M, (w, h)) + return rotated + + +def random_crop(img, size, settings, scale=None, ratio=None, + interpolation=None): + """random crop image + + Args: + img: image data + size: crop size + settings: arguments + scale: scale parameter + ratio: ratio parameter + + Returns: + random cropped image data + """ + lower_scale = settings.lower_scale + lower_ratio = settings.lower_ratio + upper_ratio = settings.upper_ratio + scale = [lower_scale, 1.0] if scale is None else scale + ratio = [lower_ratio, upper_ratio] if ratio is None else ratio + + aspect_ratio = math.sqrt(np.random.uniform(*ratio)) + w = 1. * aspect_ratio + h = 1. 
/ aspect_ratio + + bound = min((float(img.shape[0]) / img.shape[1]) / (h**2), + (float(img.shape[1]) / img.shape[0]) / (w**2)) + + scale_max = min(scale[1], bound) + scale_min = min(scale[0], bound) + + target_area = img.shape[0] * img.shape[1] * np.random.uniform(scale_min, + scale_max) + target_size = math.sqrt(target_area) + w = int(target_size * w) + h = int(target_size * h) + + i = np.random.randint(0, img.shape[0] - h + 1) + j = np.random.randint(0, img.shape[1] - w + 1) + img = img[i:i + h, j:j + w, :] + + if interpolation: + resized = cv2.resize(img, (size, size), interpolation=interpolation) + else: + resized = cv2.resize(img, (size, size)) + return resized + + +#NOTE:(2019/08/08) distort color func is not implemented +def distort_color(img): + """distort image color + + Args: + img: image data + + Returns: + distorted color image data + """ + return img + + +def resize_short(img, target_size, interpolation=None): + """resize image + + Args: + img: image data + target_size: resize short target size + interpolation: interpolation mode + + Returns: + resized image data + """ + percent = float(target_size) / min(img.shape[0], img.shape[1]) + resized_width = int(round(img.shape[1] * percent)) + resized_height = int(round(img.shape[0] * percent)) + if interpolation: + resized = cv2.resize( + img, (resized_width, resized_height), interpolation=interpolation) + else: + resized = cv2.resize(img, (resized_width, resized_height)) + return resized + + +def crop_image(img, target_size, center): + """crop image + + Args: + img: images data + target_size: crop target size + center: crop mode + + Returns: + img: cropped image data + """ + height, width = img.shape[:2] + size = target_size + if center == True: + w_start = (width - size) // 2 + h_start = (height - size) // 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img[h_start:h_end, w_start:w_end, :] + return img + + +def create_mixup_reader(settings, rd): + """ + """ + + class context: + tmp_mix = [] + tmp_l1 = [] + tmp_l2 = [] + tmp_lam = [] + + alpha = settings.mixup_alpha + + def fetch_data(): + for item in rd(): + yield item + + def mixup_data(): + for data_list in fetch_data(): + if alpha > 0.: + lam = np.random.beta(alpha, alpha) + else: + lam = 1. 
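+            # lam ~ Beta(alpha, alpha) when mixup_alpha > 0, otherwise lam = 1 (no mixing).
+            # Each image below is blended with a partner drawn from a random permutation
+            # of the same batch: mixed = lam * img1 + (1 - lam) * img2. Both labels and
+            # lam are yielded so a mixup loss can weight them the same way.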
+ l1 = np.array(data_list) + l2 = np.random.permutation(l1) + mixed_l = [ + l1[i][0] * lam + (1 - lam) * l2[i][0] for i in range(len(l1)) + ] + yield (mixed_l, l1, l2, lam) + + def mixup_reader(): + for context.tmp_mix, context.tmp_l1, context.tmp_l2, context.tmp_lam in mixup_data( + ): + for i in range(len(context.tmp_mix)): + mixed_l = context.tmp_mix[i] + l1 = context.tmp_l1[i] + l2 = context.tmp_l2[i] + lam = context.tmp_lam + yield (mixed_l, int(l1[1]), int(l2[1]), float(lam)) + + return mixup_reader + + +def process_image(sample, settings, mode, color_jitter, rotate): + """ process_image """ + + mean = settings.image_mean + std = settings.image_std + crop_size = settings.crop_size + + img_path = sample[0] + img = cv2.imread(img_path) + + if mode == 'train': + if rotate: + img = rotate_image(img) + if crop_size > 0: + img = random_crop( + img, crop_size, settings, interpolation=settings.interpolation) + if color_jitter: + img = distort_color(img) + if np.random.randint(0, 2) == 1: + img = img[:, ::-1, :] + else: + if crop_size > 0: + target_size = settings.resize_short_size + img = resize_short( + img, target_size, interpolation=settings.interpolation) + img = crop_image(img, target_size=crop_size, center=True) + + img = img[:, :, ::-1] + + if 'use_aa' in settings and settings.use_aa and mode == 'train': + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + img = policy(img) + img = np.asarray(img) + + img = img.astype('float32').transpose((2, 0, 1)) / 255 + img_mean = np.array(mean).reshape((3, 1, 1)) + img_std = np.array(std).reshape((3, 1, 1)) + img -= img_mean + img /= img_std + + if mode == 'train' or mode == 'val': + return (img, sample[1]) + elif mode == 'test': + return (img, ) + + +def process_batch_data(input_data, settings, mode, color_jitter, rotate): + batch_data = [] + for sample in input_data: + if os.path.isfile(sample[0]): + batch_data.append( + process_image(sample, settings, mode, color_jitter, rotate)) + else: + print("File not exist : %s" % sample[0]) + return batch_data + + +class ImageNetReader: + def __init__(self, seed=None): + self.shuffle_seed = seed + + def set_shuffle_seed(self, seed): + assert isinstance(seed, int), "shuffle seed must be int" + self.shuffle_seed = seed + + def _reader_creator(self, + settings, + file_list, + mode, + shuffle=False, + color_jitter=False, + rotate=False, + data_dir=None): + num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) + if mode == 'test': + batch_size = 1 + else: + batch_size = settings.batch_size / paddle.fluid.core.get_cuda_device_count( + ) + + def reader(): + def read_file_list(): + with open(file_list) as flist: + full_lines = [line.strip() for line in flist] + if mode != "test" and len(full_lines) < settings.batch_size: + print( + "Warning: The number of the whole data ({}) is smaller than the batch_size ({}), and drop_last is turnning on, so nothing will feed in program, Terminated now. Please reset batch_size to a smaller number or feed more data!" + .format(len(full_lines), settings.batch_size)) + os._exit(1) + if num_trainers > 1 and mode == "train": + assert self.shuffle_seed is not None, "multiprocess train, shuffle seed must be set!" + np.random.RandomState(self.shuffle_seed).shuffle( + full_lines) + elif shuffle: + assert self.shuffle_seed is not None, "multiprocess train, shuffle seed must be set!" 
+ np.random.RandomState(self.shuffle_seed).shuffle( + full_lines) + + batch_data = [] + for line in full_lines: + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + batch_data.append([img_path, int(label)]) + if len(batch_data) == batch_size: + if mode == 'train' or mode == 'val' or mode == 'test': + yield batch_data + + batch_data = [] + + return read_file_list + + data_reader = reader() + if mode == 'train' and num_trainers > 1: + assert self.shuffle_seed is not None, \ + "If num_trainers > 1, the shuffle_seed must be set, because " \ + "the order of batch data generated by reader " \ + "must be the same in the respective processes." + data_reader = paddle.fluid.contrib.reader.distributed_batch_reader( + data_reader) + + mapper = functools.partial( + process_batch_data, + settings=settings, + mode=mode, + color_jitter=color_jitter, + rotate=rotate) + + ret = fluid.io.xmap_readers( + mapper, + data_reader, + settings.reader_thread, + settings.reader_buf_size, + order=False) + + return ret + + def train(self, settings): + """Create a reader for trainning + + Args: + settings: arguments + + Returns: + train reader + """ + file_list = os.path.join(settings.data_dir, 'train_list.txt') + assert os.path.isfile( + file_list), "{} doesn't exist, please check data list path".format( + file_list) + + if 'use_aa' in settings and settings.use_aa: + global policy + policy = ImageNetPolicy() + + reader = self._reader_creator( + settings, + file_list, + 'train', + shuffle=True, + color_jitter=False, + rotate=False, + data_dir=settings.data_dir) + + if settings.use_mixup == True: + reader = create_mixup_reader(settings, reader) + reader = fluid.io.batch( + reader, + batch_size=int(settings.batch_size / + paddle.fluid.core.get_cuda_device_count()), + drop_last=True) + return reader + + def val(self, settings): + """Create a reader for eval + + Args: + settings: arguments + + Returns: + eval reader + """ + file_list = os.path.join(settings.data_dir, 'val_list.txt') + + assert os.path.isfile( + file_list), "{} doesn't exist, please check data list path".format( + file_list) + + return self._reader_creator( + settings, + file_list, + 'val', + shuffle=False, + data_dir=settings.data_dir) + + def test(self, settings): + """Create a reader for testing + + Args: + settings: arguments + + Returns: + test reader + """ + file_list = os.path.join(settings.data_dir, 'val_list.txt') + + assert os.path.isfile( + file_list), "{} doesn't exist, please check data list path".format( + file_list) + return self._reader_creator( + settings, + file_list, + 'test', + shuffle=False, + data_dir=settings.data_dir) diff --git a/dygraph/mobilenet/run_mul_v1.sh b/dygraph/mobilenet/run_mul_v1.sh new file mode 100644 index 00000000..27f5fd27 --- /dev/null +++ b/dygraph/mobilenet/run_mul_v1.sh @@ -0,0 +1,2 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python -m paddle.distributed.launch --log_dir ./mylog.time train.py --use_data_parallel 1 --batch_size=256 --reader_thread=8 --total_images=1281167 --class_dim=1000 --image_shape=3,224,224 --model_save_dir=output/ --lr_strategy=piecewise_decay --lr=0.1 --data_dir=../../PaddleCV/image_classification/data/ILSVRC2012 --l2_decay=3e-5 --model=MobileNetV1 diff --git a/dygraph/mobilenet/run_mul_v2.sh b/dygraph/mobilenet/run_mul_v2.sh new file mode 100644 index 00000000..aaa11ee3 --- /dev/null +++ b/dygraph/mobilenet/run_mul_v2.sh @@ -0,0 +1,2 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python -m paddle.distributed.launch --log_dir ./mylog.time train.py --use_data_parallel 1 
--batch_size=256 --reader_thread=8 --total_images=1281167 --class_dim=1000 --image_shape=3,224,224 --model_save_dir=output/ --lr_strategy=piecewise_decay --lr=0.1 --data_dir=../../PaddleCV/image_classification/data/ILSVRC2012 --l2_decay=3e-5 --model=MobileNetV2 diff --git a/dygraph/mobilenet/run_sing_v1.sh b/dygraph/mobilenet/run_sing_v1.sh new file mode 100644 index 00000000..a2dadee1 --- /dev/null +++ b/dygraph/mobilenet/run_sing_v1.sh @@ -0,0 +1,2 @@ +export CUDA_VISIBLE_DEVICES=0 +python train.py --batch_size=256 --total_images=1281167 --class_dim=1000 --image_shape=3,224,224 --model_save_dir=output/ --lr_strategy=piecewise_decay --lr=0.1 --data_dir=../../PaddleCV/image_classification/data/ILSVRC2012 --l2_decay=3e-5 --model=MobileNetV1 diff --git a/dygraph/mobilenet/run_sing_v2.sh b/dygraph/mobilenet/run_sing_v2.sh new file mode 100644 index 00000000..61234a27 --- /dev/null +++ b/dygraph/mobilenet/run_sing_v2.sh @@ -0,0 +1,2 @@ +export CUDA_VISIBLE_DEVICES=0 +python train.py --batch_size=128 --total_images=1281167 --class_dim=1000 --image_shape=3,224,224 --model_save_dir=output/ --lr_strategy=piecewise_decay --lr=0.1 --data_dir=../../PaddleCV/image_classification/data/ILSVRC2012 --model=MobileNetV2 diff --git a/dygraph/mobilenet/train.py b/dygraph/mobilenet/train.py new file mode 100644 index 00000000..0c057e63 --- /dev/null +++ b/dygraph/mobilenet/train.py @@ -0,0 +1,188 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
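The multi-GPU scripts above expose four devices and pass a global --batch_size=256; reader.py and train.py divide that value by the visible device count to get the per-device batch, while utils/optimizer.py derives steps per epoch from --total_images. A plain-Python sketch of that bookkeeping (the code itself queries paddle.fluid.core.get_cuda_device_count(); the literal 4 below just mirrors CUDA_VISIBLE_DEVICES=0,1,2,3):

```python
import math

total_images = 1281167   # ImageNet-1k training set, as passed via --total_images
global_batch = 256       # --batch_size from the launch scripts
num_devices = 4          # number of GPUs listed in CUDA_VISIBLE_DEVICES

per_device_batch = global_batch // num_devices                          # 64 samples per GPU per step
steps_per_epoch = int(math.ceil(total_images / float(global_batch)))    # 5005 steps

print(per_device_batch, steps_per_epoch)
```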
+ +from mobilenet_v1 import * +from mobilenet_v2 import * +import os +import numpy as np +import time +import sys +import sys +import numpy as np +import argparse +import ast +import paddle +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC +from paddle.fluid.dygraph.base import to_variable + +from paddle.fluid import framework + +import math +import sys +import reader +from utils import * + +IMAGENET1000 = 1281167 +base_lr = 0.1 +momentum_rate = 0.9 +l2_decay = 1e-4 + +args = parse_args() +if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: + print_arguments(args) + + +def eval(net, test_data_loader, eop): + total_loss = 0.0 + total_acc1 = 0.0 + total_acc5 = 0.0 + total_sample = 0 + t_last = 0 + for img, label in test_data_loader(): + t1 = time.time() + label = to_variable(label.numpy().astype('int64').reshape( + int(args.batch_size / paddle.fluid.core.get_cuda_device_count()), + 1)) + out = net(img) + softmax_out = fluid.layers.softmax(out, use_cudnn=False) + loss = fluid.layers.cross_entropy(input=softmax_out, label=label) + avg_loss = fluid.layers.mean(x=loss) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + t2 = time.time() + print( "test | epoch id: %d, avg_loss %0.5f acc_top1 %0.5f acc_top5 %0.5f %2.4f sec read_t:%2.4f" % \ + (eop, avg_loss.numpy(), acc_top1.numpy(), acc_top5.numpy(), t2 - t1 , t1 - t_last)) + sys.stdout.flush() + total_loss += avg_loss.numpy() + total_acc1 += acc_top1.numpy() + total_acc5 += acc_top5.numpy() + total_sample += 1 + t_last = time.time() + print("final eval loss %0.3f acc1 %0.3f acc5 %0.3f" % \ + (total_loss / total_sample, \ + total_acc1 / total_sample, total_acc5 / total_sample)) + sys.stdout.flush() + + +def train_mobilenet(): + epoch = args.num_epochs + place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ + if args.use_data_parallel else fluid.CUDAPlace(0) + with fluid.dygraph.guard(place): + if args.ce: + print("ce mode") + seed = 33 + np.random.seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + if args.use_data_parallel: + strategy = fluid.dygraph.parallel.prepare_context() + + net = None + if args.model == "MobileNetV1": + net = MobileNetV1("mobilenet_v1", class_dim=args.class_dim) + para_name = 'mobilenet_v1_params' + elif args.model == "MobileNetV2": + net = MobileNetV2( + name="mobilenet_v2", class_dim=args.class_dim, scale=1.0) + para_name = 'mobilenet_v2_params' + else: + print( + "wrong model name, please try model = MobileNetV1 or MobileNetV2" + ) + exit() + + optimizer = create_optimizer(args) + if args.use_data_parallel: + net = fluid.dygraph.parallel.DataParallel(net, strategy) + train_data_loader, train_data = utility.create_data_loader( + is_train=True, args=args) + test_data_loader, test_data = utility.create_data_loader( + is_train=False, args=args) + num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) + imagenet_reader = reader.ImageNetReader(0) + train_reader = imagenet_reader.train(settings=args) + test_reader = imagenet_reader.val(settings=args) + train_data_loader.set_sample_list_generator(train_reader, place) + test_data_loader.set_sample_list_generator(test_reader, place) + for eop in range(epoch): + if num_trainers > 1: + imagenet_reader.set_shuffle_seed(eop + ( + args.random_seed 
if args.random_seed else 0)) + net.train() + total_loss = 0.0 + total_acc1 = 0.0 + total_acc5 = 0.0 + total_sample = 0 + batch_id = 0 + t_last = 0 + for img, label in train_data_loader(): + t1 = time.time() + label = to_variable(label.numpy().astype('int64').reshape( + int(args.batch_size / + paddle.fluid.core.get_cuda_device_count()), 1)) + t_start = time.time() + out = net(img) + t_end = time.time() + softmax_out = fluid.layers.softmax(out, use_cudnn=False) + loss = fluid.layers.cross_entropy( + input=softmax_out, label=label) + avg_loss = fluid.layers.mean(x=loss) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + t_start_back = time.time() + if args.use_data_parallel: + avg_loss = net.scale_loss(avg_loss) + avg_loss.backward() + net.apply_collective_grads() + else: + avg_loss.backward() + t_end_back = time.time() + optimizer.minimize(avg_loss) + net.clear_gradients() + t2 = time.time() + train_batch_elapse = t2 - t1 + if batch_id % args.print_step == 0: + print( "epoch id: %d, batch step: %d, avg_loss %0.5f acc_top1 %0.5f acc_top5 %0.5f %2.4f sec net_t:%2.4f back_t:%2.4f read_t:%2.4f" % \ + (eop, batch_id, avg_loss.numpy(), acc_top1.numpy(), acc_top5.numpy(), train_batch_elapse, + t_end - t_start, t_end_back - t_start_back, t1 - t_last)) + sys.stdout.flush() + total_loss += avg_loss.numpy() + total_acc1 += acc_top1.numpy() + total_acc5 += acc_top5.numpy() + total_sample += 1 + batch_id += 1 + t_last = time.time() + if args.ce: + print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample)) + print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / total_sample)) + print("kpis\ttrain_loss\t%0.3f" % (total_loss / total_sample)) + print("epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f %2.4f sec" % \ + (eop, batch_id, total_loss / total_sample, \ + total_acc1 / total_sample, total_acc5 / total_sample, train_batch_elapse)) + net.eval() + eval(net, test_data_loader, eop) + save_parameters = (not args.use_data_parallel) or ( + args.use_data_parallel and + fluid.dygraph.parallel.Env().local_rank == 0) + if save_parameters: + fluid.save_dygraph(net.state_dict(), para_name) + + +if __name__ == '__main__': + train_mobilenet() diff --git a/dygraph/mobilenet/utils/__init__.py b/dygraph/mobilenet/utils/__init__.py new file mode 100644 index 00000000..4677e453 --- /dev/null +++ b/dygraph/mobilenet/utils/__init__.py @@ -0,0 +1,15 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
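train.py above reports top-1 and top-5 accuracy through fluid.layers.accuracy; for intuition, here is a small NumPy sketch of the same metric (a standalone helper written for this note, not part of the PR):

```python
import numpy as np

def topk_accuracy(logits, labels, k=1):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    topk = np.argsort(-logits, axis=1)[:, :k]            # indices of the k largest scores per row
    hits = (topk == labels.reshape(-1, 1)).any(axis=1)   # is the true class among them?
    return hits.mean()

logits = np.random.rand(8, 1000)           # a batch of 8 samples, 1000 ImageNet classes
labels = np.random.randint(0, 1000, 8)
print(topk_accuracy(logits, labels, k=1), topk_accuracy(logits, labels, k=5))
```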
+from .optimizer import cosine_decay, lr_warmup, cosine_decay_with_warmup, exponential_decay_with_warmup, Optimizer, create_optimizer +from .utility import add_arguments, print_arguments, parse_args, check_gpu, check_args, check_version, init_model, save_model, create_data_loader, print_info, best_strategy_compiled, init_model, save_model, ExponentialMovingAverage diff --git a/dygraph/mobilenet/utils/autoaugment.py b/dygraph/mobilenet/utils/autoaugment.py new file mode 100644 index 00000000..c17bf963 --- /dev/null +++ b/dygraph/mobilenet/utils/autoaugment.py @@ -0,0 +1,245 @@ +""" +This code is based on https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py +""" +from PIL import Image, ImageEnhance, ImageOps +import numpy as np +import random + + +class ImageNetPolicy(object): + """ Randomly choose one of the best 24 Sub-policies on ImageNet. + + Example: + >>> policy = ImageNetPolicy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> ImageNetPolicy(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), + SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), + SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), + SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), + SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), + SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), + SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), + SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), + SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), + SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), + SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), + SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) + ] + + def __call__(self, img, policy_idx=None): + if policy_idx is None or not isinstance(policy_idx, int): + policy_idx = random.randint(0, len(self.policies) - 1) + else: + policy_idx = policy_idx % len(self.policies) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment ImageNet Policy" + + +class CIFAR10Policy(object): + """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
+ + Example: + >>> policy = CIFAR10Policy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> CIFAR10Policy(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), + SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), + SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), + SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), + SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), + SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), + SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), + SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), + SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), + SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), + SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), + SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), + SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), + SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), + SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), + SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), + SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), + SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), + SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), + SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) + ] + + def __call__(self, img, policy_idx=None): + if policy_idx is None or not isinstance(policy_idx, int): + policy_idx = random.randint(0, len(self.policies) - 1) + else: + policy_idx = policy_idx % len(self.policies) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment CIFAR10 Policy" + + +class SVHNPolicy(object): + """ Randomly choose one of the best 25 Sub-policies on SVHN. 
+ + Example: + >>> policy = SVHNPolicy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> SVHNPolicy(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), + SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), + SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), + SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), + SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), + SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), + SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), + SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), + SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), + SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), + SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), + SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), + SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), SubPolicy( + 0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), SubPolicy( + 0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), SubPolicy( + 0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), + SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), SubPolicy( + 0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), SubPolicy( + 0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), + SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) + ] + + def __call__(self, img, policy_idx=None): + if policy_idx is None or not isinstance(policy_idx, int): + policy_idx = random.randint(0, len(self.policies) - 1) + else: + policy_idx = policy_idx % len(self.policies) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment SVHN Policy" + + +class SubPolicy(object): + def __init__(self, + p1, + operation1, + magnitude_idx1, + p2, + operation2, + magnitude_idx2, + fillcolor=(128, 128, 128)): + ranges = { + "shearX": np.linspace(0, 0.3, 10), + "shearY": np.linspace(0, 0.3, 10), + "translateX": np.linspace(0, 150 / 331, 10), + "translateY": np.linspace(0, 150 / 331, 10), + "rotate": np.linspace(0, 30, 10), + "color": np.linspace(0.0, 0.9, 10), + "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), + "solarize": np.linspace(256, 0, 10), + "contrast": np.linspace(0.0, 0.9, 10), + "sharpness": np.linspace(0.0, 0.9, 10), + "brightness": np.linspace(0.0, 0.9, 10), + "autocontrast": [0] * 10, + "equalize": [0] * 10, + "invert": [0] * 10 + } + + # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand + def rotate_with_fill(img, magnitude): + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite(rot, + Image.new("RGBA", rot.size, (128, ) * 4), + rot).convert(img.mode) + + func = { + "shearX": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, fillcolor=fillcolor), + "shearY": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), + 
Image.BICUBIC, fillcolor=fillcolor), + "translateX": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), + fillcolor=fillcolor), + "translateY": lambda img, magnitude: img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), + fillcolor=fillcolor), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), + "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( + 1 + magnitude * random.choice([-1, 1])), + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img) + } + + self.p1 = p1 + self.operation1 = func[operation1] + self.magnitude1 = ranges[operation1][magnitude_idx1] + self.p2 = p2 + self.operation2 = func[operation2] + self.magnitude2 = ranges[operation2][magnitude_idx2] + + def __call__(self, img): + if random.random() < self.p1: + img = self.operation1(img, self.magnitude1) + if random.random() < self.p2: + img = self.operation2(img, self.magnitude2) + return img diff --git a/dygraph/mobilenet/utils/dist_utils.py b/dygraph/mobilenet/utils/dist_utils.py new file mode 100755 index 00000000..29df3d3b --- /dev/null +++ b/dygraph/mobilenet/utils/dist_utils.py @@ -0,0 +1,93 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
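Each SubPolicy in autoaugment.py applies two PIL operations, each gated by its own probability and parameterized by one of ten magnitude levels, and ImageNetPolicy samples one of its 25 sub-policies per image. A self-contained sketch of that two-operation pattern (the probabilities and operations below are illustrative choices, not values from the policy tables):

```python
import random
import numpy as np
from PIL import Image, ImageOps

def apply_subpolicy(img, p1, op1, p2, op2):
    # same control flow as SubPolicy.__call__: each op fires independently
    if random.random() < p1:
        img = op1(img)
    if random.random() < p2:
        img = op2(img)
    return img

img = Image.fromarray((np.random.rand(224, 224, 3) * 255).astype('uint8'))
out = apply_subpolicy(
    img,
    0.4, lambda im: ImageOps.solarize(im, 128),   # invert pixels above the threshold
    0.6, lambda im: ImageOps.equalize(im))        # flatten the image histogram
```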
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import paddle.fluid as fluid + + +def nccl2_prepare(args, startup_prog, main_prog): + config = fluid.DistributeTranspilerConfig() + config.mode = "nccl2" + t = fluid.DistributeTranspiler(config=config) + + envs = args.dist_env + + t.transpile( + envs["trainer_id"], + trainers=','.join(envs["trainer_endpoints"]), + current_endpoint=envs["current_endpoint"], + startup_program=startup_prog, + program=main_prog) + + +def pserver_prepare(args, train_prog, startup_prog): + config = fluid.DistributeTranspilerConfig() + config.slice_var_up = args.split_var + t = fluid.DistributeTranspiler(config=config) + envs = args.dist_env + training_role = envs["training_role"] + + t.transpile( + envs["trainer_id"], + program=train_prog, + pservers=envs["pserver_endpoints"], + trainers=envs["num_trainers"], + sync_mode=not args.async_mode, + startup_program=startup_prog) + if training_role == "PSERVER": + pserver_program = t.get_pserver_program(envs["current_endpoint"]) + pserver_startup_program = t.get_startup_program( + envs["current_endpoint"], + pserver_program, + startup_program=startup_prog) + return pserver_program, pserver_startup_program + elif training_role == "TRAINER": + train_program = t.get_trainer_program() + return train_program, startup_prog + else: + raise ValueError( + 'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER' + ) + + +def nccl2_prepare_paddle(trainer_id, startup_prog, main_prog): + config = fluid.DistributeTranspilerConfig() + config.mode = "nccl2" + t = fluid.DistributeTranspiler(config=config) + t.transpile( + trainer_id, + trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'), + current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'), + startup_program=startup_prog, + program=main_prog) + + +def prepare_for_multi_process(exe, build_strategy, train_prog): + # prepare for multi-process + trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0)) + num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) + if num_trainers < 2: return + print("PADDLE_TRAINERS_NUM", num_trainers) + print("PADDLE_TRAINER_ID", trainer_id) + build_strategy.num_trainers = num_trainers + build_strategy.trainer_id = trainer_id + # NOTE(zcd): use multi processes to train the model, + # and each process use one GPU card. + startup_prog = fluid.Program() + nccl2_prepare_paddle(trainer_id, startup_prog, train_prog) + # the startup_prog are run two times, but it doesn't matter. + exe.run(startup_prog) diff --git a/dygraph/mobilenet/utils/optimizer.py b/dygraph/mobilenet/utils/optimizer.py new file mode 100644 index 00000000..7ce4f1b8 --- /dev/null +++ b/dygraph/mobilenet/utils/optimizer.py @@ -0,0 +1,299 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
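dist_utils.py above distinguishes single-process from multi-process runs purely through environment variables set by paddle.distributed.launch; a minimal standalone sketch of that probe:

```python
import os

def dist_env():
    # the same environment variables read in prepare_for_multi_process()
    trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
    num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
    return trainer_id, num_trainers

trainer_id, num_trainers = dist_env()
if num_trainers > 1:
    print("multi-process training: trainer %d of %d" % (trainer_id, num_trainers))
else:
    print("single-process training")
```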
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle.fluid as fluid +import paddle.fluid.layers.ops as ops +from paddle.fluid.initializer import init_on_cpu +from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter + + +def cosine_decay(learning_rate, step_each_epoch, epochs=120): + """Applies cosine decay to the learning rate. + lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1) + """ + global_step = _decay_step_counter() + + with init_on_cpu(): + epoch = ops.floor(global_step / step_each_epoch) + decayed_lr = learning_rate * \ + (ops.cos(epoch * (math.pi / epochs)) + 1)/2 + return decayed_lr + + +def cosine_decay_with_warmup(learning_rate, step_each_epoch, epochs=120): + """Applies cosine decay to the learning rate. + lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1) + decrease lr for every mini-batch and start with warmup. + """ + global_step = _decay_step_counter() + lr = fluid.layers.tensor.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=True, + name="learning_rate") + + warmup_epoch = fluid.layers.fill_constant( + shape=[1], dtype='float32', value=float(5), force_cpu=True) + + with init_on_cpu(): + epoch = ops.floor(global_step / step_each_epoch) + with fluid.layers.control_flow.Switch() as switch: + with switch.case(epoch < warmup_epoch): + decayed_lr = learning_rate * (global_step / + (step_each_epoch * warmup_epoch)) + fluid.layers.tensor.assign(input=decayed_lr, output=lr) + with switch.default(): + decayed_lr = learning_rate * \ + (ops.cos((global_step - warmup_epoch * step_each_epoch) * (math.pi / (epochs * step_each_epoch))) + 1)/2 + fluid.layers.tensor.assign(input=decayed_lr, output=lr) + return lr + + +def exponential_decay_with_warmup(learning_rate, + step_each_epoch, + decay_epochs, + decay_rate=0.97, + warm_up_epoch=5.0): + """Applies exponential decay to the learning rate. 
+ """ + global_step = _decay_step_counter() + lr = fluid.layers.tensor.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=True, + name="learning_rate") + + warmup_epoch = fluid.layers.fill_constant( + shape=[1], dtype='float32', value=float(warm_up_epoch), force_cpu=True) + + with init_on_cpu(): + epoch = ops.floor(global_step / step_each_epoch) + with fluid.layers.control_flow.Switch() as switch: + with switch.case(epoch < warmup_epoch): + decayed_lr = learning_rate * (global_step / + (step_each_epoch * warmup_epoch)) + fluid.layers.assign(input=decayed_lr, output=lr) + with switch.default(): + div_res = ( + global_step - warmup_epoch * step_each_epoch) / decay_epochs + div_res = ops.floor(div_res) + decayed_lr = learning_rate * (decay_rate**div_res) + fluid.layers.assign(input=decayed_lr, output=lr) + + return lr + + +def lr_warmup(learning_rate, warmup_steps, start_lr, end_lr): + """ Applies linear learning rate warmup for distributed training + Argument learning_rate can be float or a Variable + lr = lr + (warmup_rate * step / warmup_steps) + """ + assert (isinstance(end_lr, float)) + assert (isinstance(start_lr, float)) + linear_step = end_lr - start_lr + with fluid.default_main_program()._lr_schedule_guard(): + lr = fluid.layers.tensor.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=True, + name="learning_rate_warmup") + + global_step = fluid.layers.learning_rate_scheduler._decay_step_counter() + + with fluid.layers.control_flow.Switch() as switch: + with switch.case(global_step < warmup_steps): + decayed_lr = start_lr + linear_step * (global_step / + warmup_steps) + fluid.layers.tensor.assign(decayed_lr, lr) + with switch.default(): + fluid.layers.tensor.assign(learning_rate, lr) + + return lr + + +class Optimizer(object): + """A class used to represent several optimizer methods + + Attributes: + batch_size: batch size on all devices. + lr: learning rate. + lr_strategy: learning rate decay strategy. + l2_decay: l2_decay parameter. + momentum_rate: momentum rate when using Momentum optimizer. + step_epochs: piecewise decay steps. + num_epochs: number of total epochs. + + total_images: total images. + step: total steps in the an epoch. 
+ + """ + + def __init__(self, args): + self.batch_size = args.batch_size + self.lr = args.lr + self.lr_strategy = args.lr_strategy + self.l2_decay = args.l2_decay + self.momentum_rate = args.momentum_rate + self.step_epochs = args.step_epochs + self.num_epochs = args.num_epochs + self.warm_up_epochs = args.warm_up_epochs + self.decay_epochs = args.decay_epochs + self.decay_rate = args.decay_rate + self.total_images = args.total_images + + self.step = int(math.ceil(float(self.total_images) / self.batch_size)) + + def piecewise_decay(self): + """piecewise decay with Momentum optimizer + + Returns: + a piecewise_decay optimizer + """ + bd = [self.step * e for e in self.step_epochs] + lr = [self.lr * (0.1**i) for i in range(len(bd) + 1)] + learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=self.momentum_rate, + regularization=fluid.regularizer.L2Decay(self.l2_decay)) + return optimizer + + def cosine_decay(self): + """cosine decay with Momentum optimizer + + Returns: + a cosine_decay optimizer + """ + + learning_rate = fluid.layers.cosine_decay( + learning_rate=self.lr, + step_each_epoch=self.step, + epochs=self.num_epochs) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=self.momentum_rate, + regularization=fluid.regularizer.L2Decay(self.l2_decay)) + return optimizer + + def cosine_decay_warmup(self): + """cosine decay with warmup + + Returns: + a cosine_decay_with_warmup optimizer + """ + + learning_rate = cosine_decay_with_warmup( + learning_rate=self.lr, + step_each_epoch=self.step, + epochs=self.num_epochs) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=self.momentum_rate, + regularization=fluid.regularizer.L2Decay(self.l2_decay)) + return optimizer + + def exponential_decay_warmup(self): + """exponential decay with warmup + + Returns: + a exponential_decay_with_warmup optimizer + """ + + learning_rate = exponential_decay_with_warmup( + learning_rate=self.lr, + step_each_epoch=self.step, + decay_epochs=self.step * self.decay_epochs, + decay_rate=self.decay_rate, + warm_up_epoch=self.warm_up_epochs) + optimizer = fluid.optimizer.RMSProp( + learning_rate=learning_rate, + regularization=fluid.regularizer.L2Decay(self.l2_decay), + momentum=self.momentum_rate, + rho=0.9, + epsilon=0.001) + return optimizer + + def linear_decay(self): + """linear decay with Momentum optimizer + + Returns: + a linear_decay optimizer + """ + + end_lr = 0 + learning_rate = fluid.layers.polynomial_decay( + self.lr, self.step, end_lr, power=1) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=self.momentum_rate, + regularization=fluid.regularizer.L2Decay(self.l2_decay)) + + return optimizer + + def adam_decay(self): + """Adam optimizer + + Returns: + an adam_decay optimizer + """ + + return fluid.optimizer.Adam(learning_rate=self.lr) + + def cosine_decay_RMSProp(self): + """cosine decay with RMSProp optimizer + + Returns: + an cosine_decay_RMSProp optimizer + """ + + learning_rate = fluid.layers.cosine_decay( + learning_rate=self.lr, + step_each_epoch=self.step, + epochs=self.num_epochs) + optimizer = fluid.optimizer.RMSProp( + learning_rate=learning_rate, + momentum=self.momentum_rate, + regularization=fluid.regularizer.L2Decay(self.l2_decay), + # Apply epsilon=1 on ImageNet dataset. 
+ epsilon=1) + return optimizer + + def default_decay(self): + """default decay + + Returns: + default decay optimizer + """ + + optimizer = fluid.optimizer.Momentum( + learning_rate=self.lr, + momentum=self.momentum_rate, + regularization=fluid.regularizer.L2Decay(self.l2_decay)) + return optimizer + + +def create_optimizer(args): + Opt = Optimizer(args) + optimizer = getattr(Opt, args.lr_strategy)() + + return optimizer diff --git a/dygraph/mobilenet/utils/utility.py b/dygraph/mobilenet/utils/utility.py new file mode 100644 index 00000000..53678ebb --- /dev/null +++ b/dygraph/mobilenet/utils/utility.py @@ -0,0 +1,576 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import distutils.util +import numpy as np +import six +import argparse +import functools +import logging +import sys +import os +import warnings +import signal + +import paddle +import paddle.fluid as fluid +from paddle.fluid.wrapped_decorator import signature_safe_contextmanager +from paddle.fluid.framework import Program, program_guard, name_scope, default_main_program +from paddle.fluid import unique_name, layers +from utils import dist_utils + + +def print_arguments(args): + """Print argparse's arguments. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="Jonh", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + print("------------- Configuration Arguments -------------") + for arg, value in sorted(six.iteritems(vars(args))): + print("%25s : %s" % (arg, value)) + print("----------------------------------------------------") + + +def add_arguments(argname, type, default, help, argparser, **kwargs): + """Add argparse's argument. + + Usage: + + .. 
+
+        parser = argparse.ArgumentParser()
+        add_argument("name", str, "John", "User name.", parser)
+        args = parser.parse_args()
+    """
+    type = distutils.util.strtobool if type == bool else type
+    argparser.add_argument(
+        "--" + argname,
+        default=default,
+        type=type,
+        help=help + ' Default: %(default)s.',
+        **kwargs)
+
+
+def parse_args():
+    """Add arguments
+
+    Returns:
+        all training args
+    """
+    parser = argparse.ArgumentParser(description=__doc__)
+    add_arg = functools.partial(add_arguments, argparser=parser)
+    # yapf: disable
+
+    add_arg('use_data_parallel', bool, False, "The flag indicating whether to use data parallel mode to train the model.")
+    add_arg('ce', bool, False, "Whether to run continuous evaluation (CE).")
+
+    # ENV
+    add_arg('use_gpu', bool, True, "Whether to use GPU.")
+    add_arg('model_save_dir', str, "./output", "The directory path to save model.")
+    add_arg('data_dir', str, "../../PaddleCV/image_classification/data/ILSVRC2012/", "The ImageNet dataset root directory.")
+    #add_arg('data_dir', str, "../../PaddleCV/image_classification/data/", "The ImageNet dataset root directory.")
+    add_arg('pretrained_model', str, None, "Whether to load pretrained model.")
+    add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
+    add_arg('print_step', int, 10, "The steps interval to print logs.")
+    add_arg('save_step', int, 1, "The steps interval to save checkpoints.")
+
+    # SOLVER AND HYPERPARAMETERS
+    add_arg('model', str, "ResNet50", "The name of network.")
+    add_arg('total_images', int, 1281167, "The number of total training images.")
+    add_arg('num_epochs', int, 120, "The number of total epochs.")
+    add_arg('class_dim', int, 1000, "The number of total classes.")
+    add_arg('image_shape', str, "3,224,224", "The size of input image, order: [channels, height, width].")
+    add_arg('batch_size', int, 8, "Minibatch size on a device.")
+    add_arg('test_batch_size', int, 16, "Test batch size on a device.")
+    add_arg('lr', float, 0.1, "The learning rate.")
+    add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
+    add_arg('l2_decay', float, 1e-4, "The l2_decay parameter.")
+    add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
+    add_arg('warm_up_epochs', float, 5.0, "The value of warm up epochs.")
+    add_arg('decay_epochs', float, 2.4, "Decay epochs of exponential decay learning rate scheduler.")
+    add_arg('decay_rate', float, 0.97, "Decay rate of exponential decay learning rate scheduler.")
+    add_arg('drop_connect_rate', float, 0.2, "The value of drop connect rate.")
+    parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
+
+    # READER AND PREPROCESS
+    add_arg('lower_scale', float, 0.08, "The value of lower_scale in random_crop.")
+    add_arg('lower_ratio', float, 3./4., "The value of lower_ratio in random_crop.")
+    add_arg('upper_ratio', float, 4./3., "The value of upper_ratio in random_crop.")
+    add_arg('resize_short_size', int, 256, "The value of resize_short_size.")
+    add_arg('crop_size', int, 224, "The value of crop size.")
+    add_arg('use_mixup', bool, False, "Whether to use mixup.")
+    add_arg('mixup_alpha', float, 0.2, "The value of mixup_alpha.")
+    add_arg('reader_thread', int, 8, "The number of threads of the multi-thread reader.")
+    add_arg('reader_buf_size', int, 16, "The buffer size of the multi-thread reader.")
+    add_arg('interpolation', int, None, "The interpolation mode.")
+    add_arg('use_aa', bool, False, "Whether to use auto augment.")
+    parser.add_argument('--image_mean', nargs='+', type=float, default=[0.485, 0.456, 0.406], help="The mean of input image data")
+    parser.add_argument('--image_std', nargs='+', type=float, default=[0.229, 0.224, 0.225], help="The std of input image data")
+
+    # SWITCH
+    #NOTE: (2019/08/08) FP16 is moving to PaddlePaddle/Fleet now
+    #add_arg('use_fp16', bool, False, "Whether to enable half precision training with fp16." )
+    #add_arg('scale_loss', float, 1.0, "The value of scale_loss for fp16." )
+    add_arg('use_label_smoothing', bool, False, "Whether to use label_smoothing.")
+    add_arg('label_smoothing_epsilon', float, 0.1, "The value of label_smoothing_epsilon parameter.")
+    #NOTE: (2019/08/08) temporary disable use_distill
+    #add_arg('use_distill', bool, False, "Whether to use distill")
+    add_arg('random_seed', int, None, "The random seed.")
+    add_arg('use_ema', bool, False, "Whether to use ExponentialMovingAverage.")
+    add_arg('ema_decay', float, 0.9999, "The value of ema decay rate.")
+    add_arg('padding_type', str, "SAME", "Padding type of convolution.")
+    add_arg('use_se', bool, True, "Whether to use Squeeze-and-Excitation module for EfficientNet.")
+    # yapf: enable
+
+    args = parser.parse_args()
+
+    return args
+
+
+def check_gpu(use_gpu):
+    """
+    Log an error and exit when use_gpu is set to true but the CPU-only
+    version of PaddlePaddle is installed.
+    """
+    logger = logging.getLogger(__name__)
+    err = "Config use_gpu cannot be set as true while you are " \
+          "using paddlepaddle cpu version!\nPlease try:\n" \
+          "\t1. Install paddlepaddle-gpu to run the model on GPU\n" \
+          "\t2. Set use_gpu as false in the config file to run the " \
+          "model on CPU"
+
+    try:
+        if use_gpu and not fluid.is_compiled_with_cuda():
+            print(err)
+            sys.exit(1)
+    except Exception as e:
+        pass
+
+
+def check_version():
+    """
+    Log an error and exit when the installed version of PaddlePaddle is
+    not satisfied.
+    """
+    err = "PaddlePaddle version 1.6 or higher is required, " \
+          "or a suitable develop version is satisfied as well.\n" \
+          "Please make sure the installed version matches the code."
+
+    try:
+        fluid.require_version('1.6.0')
+    except Exception as e:
+        print(err)
+        sys.exit(1)
+
+
+def check_args(args):
+    """check arguments before running
+
+    Args:
+        all arguments
+    """
+
+    # check model name
+    sys.path.append("..")
+    import models
+    model_list = [m for m in dir(models) if "__" not in m]
+    assert args.model in model_list, "{} is not in the supported model list: {}, please check the model name".format(
+        args.model, model_list)
+
+    # check learning rate strategy
+    lr_strategy_list = [
+        "piecewise_decay", "cosine_decay", "linear_decay",
+        "cosine_decay_warmup", "exponential_decay_warmup"
+    ]
+    if args.lr_strategy not in lr_strategy_list:
+        warnings.warn(
+            "\n{} is not in the supported list: {}, \nusing the default learning rate strategy now.".
+            format(args.lr_strategy, lr_strategy_list))
+        args.lr_strategy = "default_decay"
+    # check conflict between GoogLeNet and mixup
+    if args.model == "GoogLeNet":
+        assert args.use_mixup == False, "Cannot use mixup processing in GoogLeNet, please set use_mixup = False."
+
+    if args.interpolation:
+        assert args.interpolation in [
+            0, 1, 2, 3, 4
+        ], "Wrong interpolation, please set:\n0: cv2.INTER_NEAREST\n1: cv2.INTER_LINEAR\n2: cv2.INTER_CUBIC\n3: cv2.INTER_AREA\n4: cv2.INTER_LANCZOS4"
+
+    if args.padding_type:
+        assert args.padding_type in [
+            "SAME", "VALID", "DYNAMIC"
+        ], "Wrong padding_type, please set one of:\nSAME\nVALID\nDYNAMIC"
+
+    assert args.checkpoint is None or args.pretrained_model is None, "Do not initialize the model from both checkpoint and pretrained_model."
+
+    # check pretrained_model path for loading
+    if args.pretrained_model is not None:
+        assert isinstance(args.pretrained_model, str)
+        assert os.path.isdir(
+            args.
+            pretrained_model), "please provide a valid pretrained_model path."
+
+    #FIXME: check checkpoint path for saving
+    if args.checkpoint is not None:
+        assert isinstance(args.checkpoint, str)
+        assert os.path.isdir(
+            args.checkpoint
+        ), "please provide a valid checkpoint path for initializing the model."
+
+    # check params for loading
+    """
+    if args.save_params:
+        assert isinstance(args.save_params, str)
+        assert os.path.isdir(
+            args.save_params), "please provide a valid save_params path."
+    """
+
+    # check gpu: when using gpu, the batch size should be divisible by the number of visible cards
+    if args.use_gpu:
+        assert args.batch_size % fluid.core.get_cuda_device_count(
+        ) == 0, "please set a batch_size ({}) that is divisible by the number of available cards ({}); you can change the number of cards with: export CUDA_VISIBLE_DEVICES= ".format(
+            args.batch_size, fluid.core.get_cuda_device_count())
+
+    # check data directory
+    assert os.path.isdir(
+        args.data_dir
+    ), "Data doesn't exist in {}, please provide the right data path".format(args.data_dir)
+
+    #check gpu
+
+    check_gpu(args.use_gpu)
+    check_version()
+
+
+def init_model(exe, args, program):
+    if args.checkpoint:
+        fluid.io.load_persistables(exe, args.checkpoint, main_program=program)
+        print("Finished loading model from %s" % (args.checkpoint))
+
+    if args.pretrained_model:
+
+        def if_exist(var):
+            return os.path.exists(os.path.join(args.pretrained_model, var.name))
+
+        fluid.io.load_vars(
+            exe,
+            args.pretrained_model,
+            main_program=program,
+            predicate=if_exist)
+
+
+def save_model(args, exe, train_prog, info):
+    model_path = os.path.join(args.model_save_dir, args.model, str(info))
+    if not os.path.isdir(model_path):
+        os.makedirs(model_path)
+    fluid.io.save_persistables(exe, model_path, main_program=train_prog)
+    print("Saved model to %s" % (model_path))
+
+
+def create_data_loader(is_train, args):
+    """create data_loader
+
+    Usage:
+        When mixup is used in training, it returns 5 results: data_loader, image, y_a (label), y_b (label) and lam; otherwise it returns 3 results: data_loader, image, and label.
+
+    Args:
+        is_train: whether the loader is for training
+        args: arguments
+
+    Returns:
+        data_loader and the input feed variables of the network
+    """
+    image_shape = [int(m) for m in args.image_shape.split(",")]
+
+    feed_image = fluid.data(
+        name="feed_image",
+        shape=[None] + image_shape,
+        dtype="float32",
+        lod_level=0)
+
+    feed_label = fluid.data(
+        name="feed_label", shape=[None, 1], dtype="int64", lod_level=0)
+    feed_y_a = fluid.data(
+        name="feed_y_a", shape=[None, 1], dtype="int64", lod_level=0)
+
+    if is_train and args.use_mixup:
+        feed_y_b = fluid.data(
+            name="feed_y_b", shape=[None, 1], dtype="int64", lod_level=0)
+        feed_lam = fluid.data(
+            name="feed_lam", shape=[None, 1], dtype="float32", lod_level=0)
+
+        data_loader = fluid.io.DataLoader.from_generator(
+            capacity=64,
+            use_double_buffer=True,
+            iterable=True,
+            return_list=True)
+
+        return data_loader, [feed_image, feed_y_a, feed_y_b, feed_lam]
+    else:
+        data_loader = fluid.io.DataLoader.from_generator(
+            capacity=64,
+            use_double_buffer=True,
+            iterable=True,
+            return_list=True)
+
+        return data_loader, [feed_image, feed_label]
+
+
+def print_info(pass_id, batch_id, print_step, metrics, time_info, info_mode):
+    """print function
+
+    Args:
+        pass_id: epoch index
+        batch_id: batch index
+        print_step: the print interval (in steps)
+        metrics: metrics to print
+        time_info: time information
+        info_mode: print mode
+    """
+    if info_mode == "batch":
+        if batch_id % print_step == 0:
+            #if isinstance(metrics,np.ndarray):
+            # train and mixup output
+            if len(metrics) == 2:
+                loss, lr = metrics
+                print(
+                    "[Pass {0}, train batch {1}] \tloss {2}, lr {3}, elapse {4}".
+                    format(pass_id, batch_id, "%.5f" % loss, "%.5f" % lr,
+                           "%2.4f sec" % time_info))
+            # train and no mixup output
+            elif len(metrics) == 4:
+                loss, acc1, acc5, lr = metrics
+                print(
+                    "[Pass {0}, train batch {1}] \tloss {2}, acc1 {3}, acc5 {4}, lr {5}, elapse {6}".
+                    format(pass_id, batch_id, "%.5f" % loss, "%.5f" % acc1,
+                           "%.5f" % acc5, "%.5f" % lr, "%2.4f sec" % time_info))
+            # test output
+            elif len(metrics) == 3:
+                loss, acc1, acc5 = metrics
+                print(
+                    "[Pass {0}, test batch {1}] \tloss {2}, acc1 {3}, acc5 {4}, elapse {5}".
+                    format(pass_id, batch_id, "%.5f" % loss, "%.5f" % acc1,
+                           "%.5f" % acc5, "%2.4f sec" % time_info))
+            else:
+                raise Exception(
+                    "length of metrics {} is not supported, it may be caused by a wrong format of build_program_output".
+                    format(len(metrics)))
+            sys.stdout.flush()
+
+    elif info_mode == "epoch":
+        ## TODO add time elapse
+        #if isinstance(metrics,np.ndarray):
+        if len(metrics) == 5:
+            train_loss, _, test_loss, test_acc1, test_acc5 = metrics
+            print(
+                "[End pass {0}]\ttrain_loss {1}, test_loss {2}, test_acc1 {3}, test_acc5 {4}".
+                format(pass_id, "%.5f" % train_loss, "%.5f" % test_loss, "%.5f"
+                       % test_acc1, "%.5f" % test_acc5))
+        elif len(metrics) == 7:
+            train_loss, train_acc1, train_acc5, _, test_loss, test_acc1, test_acc5 = metrics
+            print(
+                "[End pass {0}]\ttrain_loss {1}, train_acc1 {2}, train_acc5 {3}, test_loss {4}, test_acc1 {5}, test_acc5 {6}".
+ format(pass_id, "%.5f" % train_loss, "%.5f" % train_acc1, "%.5f" + % train_acc5, "%.5f" % test_loss, "%.5f" % test_acc1, + "%.5f" % test_acc5)) + sys.stdout.flush() + elif info_mode == "ce": + raise Warning("CE code is not ready") + else: + raise Exception("Illegal info_mode") + + +def best_strategy_compiled(args, program, loss, exe): + """make a program which wrapped by a compiled program + """ + + if os.getenv('FLAGS_use_ngraph'): + return program + else: + build_strategy = fluid.compiler.BuildStrategy() + #Feature will be supported in Fluid v1.6 + #build_strategy.enable_inplace = True + + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_threads = fluid.core.get_cuda_device_count() + exec_strategy.num_iteration_per_drop_scope = 10 + + num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) + if num_trainers > 1 and args.use_gpu: + dist_utils.prepare_for_multi_process(exe, build_strategy, program) + # NOTE: the process is fast when num_threads is 1 + # for multi-process training. + exec_strategy.num_threads = 1 + + compiled_program = fluid.CompiledProgram(program).with_data_parallel( + loss_name=loss.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + + return compiled_program + + +class ExponentialMovingAverage(object): + def __init__(self, + decay=0.999, + thres_steps=None, + zero_debias=False, + name=None): + self._decay = decay + self._thres_steps = thres_steps + self._name = name if name is not None else '' + self._decay_var = self._get_ema_decay() + + self._params_tmps = [] + for param in default_main_program().global_block().all_parameters(): + if param.do_model_average != False: + tmp = param.block.create_var( + name=unique_name.generate(".".join( + [self._name + param.name, 'ema_tmp'])), + dtype=param.dtype, + persistable=False, + stop_gradient=True) + self._params_tmps.append((param, tmp)) + + self._ema_vars = {} + for param, tmp in self._params_tmps: + with param.block.program._optimized_guard( + [param, tmp]), name_scope('moving_average'): + self._ema_vars[param.name] = self._create_ema_vars(param) + + self.apply_program = Program() + block = self.apply_program.global_block() + with program_guard(main_program=self.apply_program): + decay_pow = self._get_decay_pow(block) + for param, tmp in self._params_tmps: + param = block._clone_variable(param) + tmp = block._clone_variable(tmp) + ema = block._clone_variable(self._ema_vars[param.name]) + layers.assign(input=param, output=tmp) + # bias correction + if zero_debias: + ema = ema / (1.0 - decay_pow) + layers.assign(input=ema, output=param) + + self.restore_program = Program() + block = self.restore_program.global_block() + with program_guard(main_program=self.restore_program): + for param, tmp in self._params_tmps: + tmp = block._clone_variable(tmp) + param = block._clone_variable(param) + layers.assign(input=tmp, output=param) + + def _get_ema_decay(self): + with default_main_program()._lr_schedule_guard(): + decay_var = layers.tensor.create_global_var( + shape=[1], + value=self._decay, + dtype='float32', + persistable=True, + name="scheduled_ema_decay_rate") + + if self._thres_steps is not None: + decay_t = (self._thres_steps + 1.0) / (self._thres_steps + 10.0) + with layers.control_flow.Switch() as switch: + with switch.case(decay_t < self._decay): + layers.tensor.assign(decay_t, decay_var) + with switch.default(): + layers.tensor.assign( + np.array( + [self._decay], dtype=np.float32), + decay_var) + return decay_var + + def _get_decay_pow(self, block): + global_steps = 
layers.learning_rate_scheduler._decay_step_counter() + decay_var = block._clone_variable(self._decay_var) + decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1) + return decay_pow_acc + + def _create_ema_vars(self, param): + param_ema = layers.create_global_var( + name=unique_name.generate(self._name + param.name + '_ema'), + shape=param.shape, + value=0.0, + dtype=param.dtype, + persistable=True) + + return param_ema + + def update(self): + """ + Update Exponential Moving Average. Should only call this method in + train program. + """ + param_master_emas = [] + for param, tmp in self._params_tmps: + with param.block.program._optimized_guard( + [param, tmp]), name_scope('moving_average'): + param_ema = self._ema_vars[param.name] + if param.name + '.master' in self._ema_vars: + master_ema = self._ema_vars[param.name + '.master'] + param_master_emas.append([param_ema, master_ema]) + else: + ema_t = param_ema * self._decay_var + param * ( + 1 - self._decay_var) + layers.assign(input=ema_t, output=param_ema) + + # for fp16 params + for param_ema, master_ema in param_master_emas: + default_main_program().global_block().append_op( + type="cast", + inputs={"X": master_ema}, + outputs={"Out": param_ema}, + attrs={ + "in_dtype": master_ema.dtype, + "out_dtype": param_ema.dtype + }) + + @signature_safe_contextmanager + def apply(self, executor, need_restore=True): + """ + Apply moving average to parameters for evaluation. + + Args: + executor (Executor): The Executor to execute applying. + need_restore (bool): Whether to restore parameters after applying. + """ + executor.run(self.apply_program) + try: + yield + finally: + if need_restore: + self.restore(executor) + + def restore(self, executor): + """Restore parameters. + + Args: + executor (Executor): The Executor to execute restoring. + """ + executor.run(self.restore_program) -- GitLab
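
For reference, below is a minimal sketch of how the `ExponentialMovingAverage` helper added in `dygraph/mobilenet/utils/utility.py` can be exercised in a fluid static-graph program (it builds `Program` objects, so it is meant for static graph, not for `dygraph.guard()`). This is only an illustration under assumptions: PaddlePaddle 1.6.x, a toy fc regression network, random data, and the default parameter name `fc_0.w_0`; none of these are part of the patch.

    import numpy as np
    import paddle.fluid as fluid

    from utils.utility import ExponentialMovingAverage

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        # A toy regression network, just to give the EMA helper parameters to track.
        x = fluid.data(name="x", shape=[None, 4], dtype="float32")
        y = fluid.data(name="y", shape=[None, 1], dtype="float32")
        pred = fluid.layers.fc(input=x, size=1)
        loss = fluid.layers.reduce_mean(fluid.layers.square_error_cost(pred, y))
        fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
        # Build the EMA inside the program guard so its shadow variables and
        # update ops are added to main_prog / startup_prog.
        ema = ExponentialMovingAverage(decay=0.999)
        ema.update()

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)

    for _ in range(10):
        feed = {
            "x": np.random.rand(8, 4).astype("float32"),
            "y": np.random.rand(8, 1).astype("float32"),
        }
        exe.run(main_prog, feed=feed, fetch_list=[loss])

    # ema.apply() swaps the moving-average values into the parameters for the
    # duration of the `with` block and restores the raw weights afterwards.
    with ema.apply(exe):
        # "fc_0.w_0" is the default name fluid gives the first fc weight here;
        # adjust it if your program names parameters differently.
        averaged_w = np.array(
            fluid.global_scope().find_var("fc_0.w_0").get_tensor())
        print("EMA weight:", averaged_w.ravel())

Because `ema.update()` appends the shadow-variable update ops to the training program, the averages advance automatically on every `exe.run(main_prog, ...)`; no extra call is needed inside the training loop.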