提交 4ccfca29 编写于 作者: S shippingwang

add autoargument

上级 32ce6837
mode: 'train'
ARCHITECTURE:
name: "EfficientNetB0"
drop_connect_rate: 0.1
padding_type : "SAME"
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 360
topk: 5
image_shape: [3, 224, 224]
use_ema: True
ema_decay: 0.9999
use_aa: True
ls_epsilon: 0.1
LEARNING_RATE:
function: 'ExponentialWarmup'
params:
lr: 0.032
OPTIMIZER:
function: 'RMSProp'
params:
momentum: 0.9
rho: 0.9
epsilon: 0.001
regularizer:
function: 'L2'
factor: 0.00001
TRAIN:
batch_size: 512
num_workers: 4
file_list: "./dataset/ILSVRC2012/train_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: Fals
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- AA:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
VALID:
batch_size: 128
num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- ResizeImage:
interpolation: 2
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
......@@ -25,6 +25,7 @@ import random
import cv2
import numpy as np
from autoargument import ImageNetPolicy
class OperatorParamError(ValueError):
""" OperatorParamError
......@@ -171,6 +172,18 @@ class RandFlipImage(object):
else:
return img
class AA(object):
def __init__(self):
self.policy = ImageNetPolicy()
def __call__(self,img):
from PIL import Image
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
img = self.policy(img)
img = np.asarray(img)
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
......
......@@ -145,6 +145,59 @@ class CosineWarmup(object):
return learning_rate
class ExponentialWarmup(object):
"""
Exponential learning rate decay with warmup
[0, warmup_epoch): linear warmup
[warmup_epoch, epochs): Exponential decay
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
decay_epochs(float): decay epochs
decay_rate(float): decay rate
warmup_epoch(int): epoch num of warmup
"""
def __init__(self, lr, step_each_epoch, decay_epochs=2.4, decay_rate=0.97, warmup_epoch=5, **kwargs):
super(CosineWarmup, self).__init__()
self.lr = lr
self.step_each_epoch = step_each_epoch
self.decay_epochs = decay_epochs * self.step_each_epoch
self.decay_rate = decay_rate
self.warmup_epoch = fluid.layers.fill_constant(
shape=[1],
value=float(warmup_epoch),
dtype='float32',
force_cpu=True)
def __call__(self):
global_step = _decay_step_counter()
learning_rate = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
epoch = ops.floor(global_step / self.step_each_epoch)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(epoch < self.warmup_epoch):
decayed_lr = self.lr * \
(global_step / (self.step_each_epoch * self.warmup_epoch))
fluid.layers.tensor.assign(
input=decayed_lr, output=learning_rate)
with switch.default():
rest_step = global_step - self.warmup_epoch * self.step_each_epoch
div_res = ops.floor(rest_step / self.decay_epochs)
decayed_lr = self.lr*(self.decay_rate**div_res)
fluid.layers.tensor.assign(
input=decayed_lr, output=learning_rate)
return learning_rate
class LearningRateBuilder():
"""
Build learning rate variable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册