From 4ccfca291b09a01cb86725c7b49aa7f910f27adf Mon Sep 17 00:00:00 2001
From: shippingwang
Date: Wed, 6 May 2020 11:17:39 +0000
Subject: [PATCH] add autoargument

---
 configs/EfficientNet/EfficientNetB0.yaml | 84 ++++++++++++++++++++++++
 ppcls/data/imaug/operators.py            | 14 ++++
 ppcls/optimizer/learning_rate.py         | 53 +++++++++++++++
 3 files changed, 151 insertions(+)
 create mode 100644 configs/EfficientNet/EfficientNetB0.yaml

diff --git a/configs/EfficientNet/EfficientNetB0.yaml b/configs/EfficientNet/EfficientNetB0.yaml
new file mode 100644
index 00000000..ae28a3cc
--- /dev/null
+++ b/configs/EfficientNet/EfficientNetB0.yaml
@@ -0,0 +1,84 @@
+mode: 'train'
+ARCHITECTURE:
+    name: "EfficientNetB0"
+    drop_connect_rate: 0.1
+    padding_type: "SAME"
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1281167
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 360
+topk: 5
+image_shape: [3, 224, 224]
+use_ema: True
+ema_decay: 0.9999
+use_aa: True
+ls_epsilon: 0.1
+
+LEARNING_RATE:
+    function: 'ExponentialWarmup'
+    params:
+        lr: 0.032
+
+OPTIMIZER:
+    function: 'RMSProp'
+    params:
+        momentum: 0.9
+        rho: 0.9
+        epsilon: 0.001
+    regularizer:
+        function: 'L2'
+        factor: 0.00001
+
+TRAIN:
+    batch_size: 512
+    num_workers: 4
+    file_list: "./dataset/ILSVRC2012/train_list.txt"
+    data_dir: "./dataset/ILSVRC2012/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - AA:
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+
+
+
+VALID:
+    batch_size: 128
+    num_workers: 4
+    file_list: "./dataset/ILSVRC2012/val_list.txt"
+    data_dir: "./dataset/ILSVRC2012/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - ResizeImage:
+            interpolation: 2
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+
diff --git a/ppcls/data/imaug/operators.py b/ppcls/data/imaug/operators.py
index a8454740..7da40697 100644
--- a/ppcls/data/imaug/operators.py
+++ b/ppcls/data/imaug/operators.py
@@ -25,6 +25,7 @@ import random
 
 import cv2
 import numpy as np
+from .autoargument import ImageNetPolicy
 
 class OperatorParamError(ValueError):
     """ OperatorParamError
@@ -171,6 +172,19 @@ class RandFlipImage(object):
         else:
             return img
 
+class AA(object):
+    """ apply AutoAugment (ImageNet policy) to the input image """
+
+    def __init__(self):
+        self.policy = ImageNetPolicy()
+
+    def __call__(self, img):
+        from PIL import Image
+        img = np.ascontiguousarray(img)
+        img = Image.fromarray(img)  # the policy operates on PIL Images
+        img = self.policy(img)
+        img = np.asarray(img)
+        return img
 
 class NormalizeImage(object):
     """ normalize image such as substract mean, divide std
diff --git a/ppcls/optimizer/learning_rate.py b/ppcls/optimizer/learning_rate.py
index 197f8af1..b6de18ba 100644
--- a/ppcls/optimizer/learning_rate.py
+++ b/ppcls/optimizer/learning_rate.py
@@ -145,6 +145,59 @@ class CosineWarmup(object):
 
         return learning_rate
 
+class ExponentialWarmup(object):
+    """
+    Exponential learning rate decay with warmup:
+    [0, warmup_epoch): linear warmup
+    [warmup_epoch, epochs): exponential decay
+
+    Args:
+        lr(float): initial learning rate
+        step_each_epoch(int): number of steps in each epoch
+        decay_epochs(float): number of epochs between two decays
+        decay_rate(float): decay rate
+        warmup_epoch(int): epoch num of warmup
+    """
+
+    def __init__(self, lr, step_each_epoch, decay_epochs=2.4, decay_rate=0.97,
+                 warmup_epoch=5, **kwargs):
+        super(ExponentialWarmup, self).__init__()
+        self.lr = lr
+        self.step_each_epoch = step_each_epoch
+        self.decay_epochs = decay_epochs * self.step_each_epoch
+        self.decay_rate = decay_rate
+        self.warmup_epoch = fluid.layers.fill_constant(
+            shape=[1],
+            value=float(warmup_epoch),
+            dtype='float32',
+            force_cpu=True)
+
+    def __call__(self):
+        global_step = _decay_step_counter()
+        learning_rate = fluid.layers.tensor.create_global_var(
+            shape=[1],
+            value=0.0,
+            dtype='float32',
+            persistable=True,
+            name="learning_rate")
+
+        epoch = ops.floor(global_step / self.step_each_epoch)
+        with fluid.layers.control_flow.Switch() as switch:
+            with switch.case(epoch < self.warmup_epoch):
+                # linear warmup: scale lr by the fraction of warmup done
+                decayed_lr = self.lr * \
+                    (global_step / (self.step_each_epoch * self.warmup_epoch))
+                fluid.layers.tensor.assign(
+                    input=decayed_lr, output=learning_rate)
+            with switch.default():
+                rest_step = global_step - self.warmup_epoch * self.step_each_epoch
+                div_res = ops.floor(rest_step / self.decay_epochs)  # decays so far
+                decayed_lr = self.lr * (self.decay_rate**div_res)
+                fluid.layers.tensor.assign(
+                    input=decayed_lr, output=learning_rate)
+
+        return learning_rate
+
 class LearningRateBuilder():
     """
     Build learning rate variable
-- 
GitLab
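
For intuition, here is a minimal pure-Python sketch of the schedule that the Switch block in ExponentialWarmup builds as a Paddle graph. The function name is hypothetical, and the step_each_epoch default (1281167 images / batch 512, roughly 2502 steps) is an illustrative assumption, not part of the patch:

import math

def exponential_warmup_lr(step, base_lr=0.032, step_each_epoch=2502,
                          warmup_epoch=5, decay_epochs=2.4, decay_rate=0.97):
    """Hypothetical single-step mirror of ExponentialWarmup.__call__."""
    warmup_steps = step_each_epoch * warmup_epoch
    if step < warmup_steps:
        # linear warmup: ramp from 0 to base_lr over the first warmup_epoch epochs
        return base_lr * step / warmup_steps
    # staircase exponential decay: multiply by decay_rate every decay_epochs epochs
    num_decays = math.floor((step - warmup_steps) / (step_each_epoch * decay_epochs))
    return base_lr * decay_rate ** num_decays

print(exponential_warmup_lr(6255))   # mid-warmup: 0.016
print(exponential_warmup_lr(30000))  # two decays past warmup: ~0.0301

Note that the warmup branch in the patch tests epoch < warmup_epoch with epoch = floor(step / step_each_epoch), which is equivalent to the step < warmup_steps test above.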