Unverified commit c962e805, authored by R ruri, committed by GitHub

Merge pull request #101 from shippingwang/add_efficientnet

add ema and EfficientNet
......@@ -33,6 +33,3 @@
- id: trailing-whitespace
files: \.(md|yml)$
- id: check-case-conflict
- id: flake8
args: ['--ignore=E265']
mode: 'train'
ARCHITECTURE:
name: "EfficientNetB0"
drop_connect_rate: 0.1
padding_type: "SAME"
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 360
topk: 5
image_shape: [3, 224, 224]
use_ema: True
ema_decay: 0.9999
use_aa: True
ls_epsilon: 0.1
LEARNING_RATE:
function: 'ExponentialWarmup'
params:
lr: 0.032
OPTIMIZER:
function: 'RMSProp'
params:
momentum: 0.9
rho: 0.9
epsilon: 0.001
regularizer:
function: 'L2'
factor: 0.00001
TRAIN:
batch_size: 512
num_workers: 4
file_list: "./dataset/ILSVRC2012/train_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- RandCropImage:
size: 224
interpolation: 2
- RandFlipImage:
flip_code: 1
- AutoAugment:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
VALID:
batch_size: 128
num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- ResizeImage:
interpolation: 2
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
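For readers following along, a minimal sketch of how a config like the one above could be loaded and queried (plain PyYAML; the file path and the load_config helper are assumptions, the keys come straight from the YAML):

import yaml

def load_config(path):
    # Parse the trainer configuration; top-level keys mirror the YAML above.
    with open(path) as f:
        return yaml.safe_load(f)

cfg = load_config('configs/EfficientNet/EfficientNetB0.yaml')  # path is an assumption
print(cfg['ARCHITECTURE']['name'])               # EfficientNetB0
print(cfg.get('use_ema'), cfg.get('ema_decay'))  # True 0.9999
print(cfg['LEARNING_RATE']['function'])          # ExponentialWarmup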
......@@ -25,6 +25,8 @@ import random
import cv2
import numpy as np
from .autoaugment import ImageNetPolicy
class OperatorParamError(ValueError):
""" OperatorParamError
......@@ -115,7 +117,9 @@ class CropImage(object):
class RandCropImage(object):
""" random crop image """
def __init__(self, size, scale=None, ratio=None):
def __init__(self, size, scale=None, ratio=None, interpolation=-1):
self.interpolation = interpolation if interpolation >= 0 else None
if type(size) is int:
self.size = (size, size) # (h, w)
else:
......@@ -149,7 +153,10 @@ class RandCropImage(object):
j = random.randint(0, img_h - h)
img = img[j:j + h, i:i + w, :]
return cv2.resize(img, size)
if self.interpolation is None:
return cv2.resize(img, size)
else:
return cv2.resize(img, size, interpolation=self.interpolation)
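The integer interpolation values passed through the config (interpolation: 2 in the transforms above) are OpenCV flags; a small sketch of the mapping and of the fallback behaviour introduced here, under the assumption that the values follow the standard cv2 constants:

import cv2

# OpenCV interpolation flags referenced by the integer config values.
INTERP_FLAGS = {
    0: cv2.INTER_NEAREST,
    1: cv2.INTER_LINEAR,
    2: cv2.INTER_CUBIC,
    3: cv2.INTER_AREA,
    4: cv2.INTER_LANCZOS4,
}

def resize_with_flag(img, size, interpolation=-1):
    # interpolation < 0 falls back to cv2's default (INTER_LINEAR),
    # mirroring the RandCropImage change above.
    if interpolation < 0:
        return cv2.resize(img, size)
    return cv2.resize(img, size, interpolation=interpolation)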
class RandFlipImage(object):
......@@ -172,6 +179,18 @@ class RandFlipImage(object):
return img
class AutoAugment(object):
def __init__(self):
self.policy = ImageNetPolicy()
def __call__(self, img):
from PIL import Image
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
img = self.policy(img)
img = np.asarray(img)
return img
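A minimal usage sketch of the operator above, assuming the class as defined here: the input is an HWC uint8 NumPy array (as produced by the earlier transforms), which is round-tripped through PIL so ImageNetPolicy can be applied; the sample image is made up.

import numpy as np

# Hypothetical HWC uint8 crop, as produced by DecodeImage + RandCropImage.
img = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)

aug = AutoAugment()   # the wrapper defined above; needs Pillow installed
img = aug(img)        # back to a NumPy array of the same shape
assert np.asarray(img).shape == (224, 224, 3)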
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
......
......@@ -383,7 +383,9 @@ class EfficientNet():
use_bias=True,
padding_type=self.padding_type,
name=name + '_se_expand')
se_out = inputs * fluid.layers.sigmoid(x_squeezed)
#se_out = inputs * fluid.layers.sigmoid(x_squeezed)
se_out = fluid.layers.elementwise_mul(
inputs, fluid.layers.sigmoid(x_squeezed), axis=-1)
return se_out
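The change keeps the squeeze-and-excitation gate numerically the same but expresses it with elementwise_mul and explicit broadcasting. A NumPy sketch of what the gate computes, with illustrative shapes:

import numpy as np

def se_gate(inputs, x_squeezed):
    # inputs:     (N, C, H, W) feature map
    # x_squeezed: (N, C, 1, 1) per-channel excitation logits
    # Each channel is rescaled by sigmoid(x_squeezed), broadcast over H and W.
    gate = 1.0 / (1.0 + np.exp(-x_squeezed))
    return inputs * gate

out = se_gate(np.ones((2, 8, 4, 4)), np.zeros((2, 8, 1, 1)))
print(out[0, 0, 0, 0])  # 0.5, since sigmoid(0) == 0.5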
def extract_features(self, inputs, is_test):
......@@ -467,8 +469,8 @@ class BlockDecoder(object):
# Check stride
cond_1 = ('s' in options and len(options['s']) == 1)
cond_2 = ((len(options['s']) == 2)
and (options['s'][0] == options['s'][1]))
cond_2 = ((len(options['s']) == 2) and
(options['s'][0] == options['s'][1]))
assert (cond_1 or cond_2)
return BlockArgs(
......
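The stride check above inspects the 's' field parsed out of an EfficientNet block string. A hedged sketch of the string format the decoder appears to expect, in the style of the reference EfficientNet implementation (the helper name is hypothetical):

import re

def decode_block_string(block_string):
    # Split e.g. 'r1_k3_s11_e1_i32_o16_se0.25' into
    # {'r': '1', 'k': '3', 's': '11', 'e': '1', 'i': '32', 'o': '16', 'se': '0.25'}
    options = {}
    for op in block_string.split('_'):
        splits = re.split(r'(\d.*)', op)
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value
    return options

opts = decode_block_string('r1_k3_s11_e1_i32_o16_se0.25')
# The stride check above then accepts len(opts['s']) == 1, or length 2 with
# both characters equal (e.g. '11' or '22').
print(opts['s'])  # '11'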
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
......@@ -130,7 +130,7 @@ class CosineWarmup(object):
with fluid.layers.control_flow.Switch() as switch:
with switch.case(epoch < self.warmup_epoch):
decayed_lr = self.lr * \
(global_step / (self.step_each_epoch * self.warmup_epoch))
(global_step / (self.step_each_epoch * self.warmup_epoch))
fluid.layers.tensor.assign(
input=decayed_lr, output=learning_rate)
with switch.default():
......@@ -145,6 +145,65 @@ class CosineWarmup(object):
return learning_rate
class ExponentialWarmup(object):
"""
Exponential learning rate decay with warmup
[0, warmup_epoch): linear warmup
[warmup_epoch, epochs): Exponential decay
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
decay_epochs(float): decay epochs
decay_rate(float): decay rate
warmup_epoch(int): epoch num of warmup
"""
def __init__(self,
lr,
step_each_epoch,
decay_epochs=2.4,
decay_rate=0.97,
warmup_epoch=5,
**kwargs):
super(ExponentialWarmup, self).__init__()
self.lr = lr
self.step_each_epoch = step_each_epoch
self.decay_epochs = decay_epochs * self.step_each_epoch
self.decay_rate = decay_rate
self.warmup_epoch = fluid.layers.fill_constant(
shape=[1],
value=float(warmup_epoch),
dtype='float32',
force_cpu=True)
def __call__(self):
global_step = _decay_step_counter()
learning_rate = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
epoch = ops.floor(global_step / self.step_each_epoch)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(epoch < self.warmup_epoch):
decayed_lr = self.lr * \
(global_step / (self.step_each_epoch * self.warmup_epoch))
fluid.layers.tensor.assign(
input=decayed_lr, output=learning_rate)
with switch.default():
rest_step = global_step - self.warmup_epoch * self.step_each_epoch
div_res = ops.floor(rest_step / self.decay_epochs)
decayed_lr = self.lr * (self.decay_rate**div_res)
fluid.layers.tensor.assign(
input=decayed_lr, output=learning_rate)
return learning_rate
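A plain-Python sketch of the schedule built above, handy for sanity-checking the values the fluid Switch produces. The numbers come from the EfficientNetB0 config (lr 0.032, warmup_epoch 5, decay_epochs 2.4, decay_rate 0.97); step_each_epoch is approximated as total_images / batch_size and is an assumption:

import math

def exponential_warmup_lr(global_step, lr=0.032, step_each_epoch=2502,
                          decay_epochs=2.4, decay_rate=0.97, warmup_epoch=5):
    # [0, warmup_epoch) epochs: linear warmup from 0 to lr.
    # Afterwards: lr * decay_rate ** floor(rest_step / decay_steps).
    warmup_steps = step_each_epoch * warmup_epoch
    if global_step < warmup_steps:
        return lr * global_step / warmup_steps
    rest_step = global_step - warmup_steps
    decay_steps = decay_epochs * step_each_epoch
    return lr * decay_rate ** math.floor(rest_step / decay_steps)

print(exponential_warmup_lr(0))                # 0.0
print(exponential_warmup_lr(2502 * 5))         # 0.032, warmup just finished
print(exponential_warmup_lr(2502 * 5 + 6005))  # 0.032 * 0.97, one decay interval later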
class LearningRateBuilder():
"""
Build learning rate variable
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.framework import Program, program_guard, name_scope, default_main_program
from paddle.fluid import unique_name, layers
class ExponentialMovingAverage(object):
def __init__(self,
decay=0.999,
thres_steps=None,
zero_debias=False,
name=None):
self._decay = decay
self._thres_steps = thres_steps
self._name = name if name is not None else ''
self._decay_var = self._get_ema_decay()
self._params_tmps = []
for param in default_main_program().global_block().all_parameters():
if param.do_model_average != False:
tmp = param.block.create_var(
name=unique_name.generate(".".join(
[self._name + param.name, 'ema_tmp'])),
dtype=param.dtype,
persistable=False,
stop_gradient=True)
self._params_tmps.append((param, tmp))
self._ema_vars = {}
for param, tmp in self._params_tmps:
with param.block.program._optimized_guard(
[param, tmp]), name_scope('moving_average'):
self._ema_vars[param.name] = self._create_ema_vars(param)
self.apply_program = Program()
block = self.apply_program.global_block()
with program_guard(main_program=self.apply_program):
decay_pow = self._get_decay_pow(block)
for param, tmp in self._params_tmps:
param = block._clone_variable(param)
tmp = block._clone_variable(tmp)
ema = block._clone_variable(self._ema_vars[param.name])
layers.assign(input=param, output=tmp)
# bias correction
if zero_debias:
ema = ema / (1.0 - decay_pow)
layers.assign(input=ema, output=param)
self.restore_program = Program()
block = self.restore_program.global_block()
with program_guard(main_program=self.restore_program):
for param, tmp in self._params_tmps:
tmp = block._clone_variable(tmp)
param = block._clone_variable(param)
layers.assign(input=tmp, output=param)
def _get_ema_decay(self):
with default_main_program()._lr_schedule_guard():
decay_var = layers.tensor.create_global_var(
shape=[1],
value=self._decay,
dtype='float32',
persistable=True,
name="scheduled_ema_decay_rate")
if self._thres_steps is not None:
decay_t = (self._thres_steps + 1.0) / (self._thres_steps + 10.0)
with layers.control_flow.Switch() as switch:
with switch.case(decay_t < self._decay):
layers.tensor.assign(decay_t, decay_var)
with switch.default():
layers.tensor.assign(
np.array(
[self._decay], dtype=np.float32),
decay_var)
return decay_var
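The thres_steps branch above ramps the decay up from a small value so the zero-initialized shadow variables do not dominate early averages; a quick numeric check of (t + 1) / (t + 10) capped at the configured decay (step values are illustrative):

decay = 0.9999
for t in [0, 10, 100, 1000, 100000]:
    decay_t = (t + 1.0) / (t + 10.0)
    print(t, min(decay_t, decay))
# 0 -> 0.1, 10 -> 0.55, 100 -> ~0.918, 1000 -> ~0.991, 100000 -> capped at 0.9999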
def _get_decay_pow(self, block):
global_steps = layers.learning_rate_scheduler._decay_step_counter()
decay_var = block._clone_variable(self._decay_var)
decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1)
return decay_pow_acc
def _create_ema_vars(self, param):
param_ema = layers.create_global_var(
name=unique_name.generate(self._name + param.name + '_ema'),
shape=param.shape,
value=0.0,
dtype=param.dtype,
persistable=True)
return param_ema
def update(self):
"""
Update the Exponential Moving Average. This method should only be called in
the train program.
"""
param_master_emas = []
for param, tmp in self._params_tmps:
with param.block.program._optimized_guard(
[param, tmp]), name_scope('moving_average'):
param_ema = self._ema_vars[param.name]
if param.name + '.master' in self._ema_vars:
master_ema = self._ema_vars[param.name + '.master']
param_master_emas.append([param_ema, master_ema])
else:
ema_t = param_ema * self._decay_var + param * (
1 - self._decay_var)
layers.assign(input=ema_t, output=param_ema)
# for fp16 params
for param_ema, master_ema in param_master_emas:
default_main_program().global_block().append_op(
type="cast",
inputs={"X": master_ema},
outputs={"Out": param_ema},
attrs={
"in_dtype": master_ema.dtype,
"out_dtype": param_ema.dtype
})
@signature_safe_contextmanager
def apply(self, executor, need_restore=True):
"""
Apply moving average to parameters for evaluation.
Args:
executor (Executor): The Executor to execute applying.
need_restore (bool): Whether to restore parameters after applying.
"""
executor.run(self.apply_program)
try:
yield
finally:
if need_restore:
self.restore(executor)
def restore(self, executor):
"""Restore parameters.
Args:
executor (Executor): The Executor to execute restoring.
"""
executor.run(self.restore_program)
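Stripped of the fluid program machinery, update() and the zero_debias correction in apply() reduce to the standard EMA recurrence. A NumPy sketch with a made-up parameter value and the decay from the config above:

import numpy as np

decay = 0.9999
param = np.array([1.0])     # hypothetical parameter value, updated by the optimizer
ema = np.zeros_like(param)  # _create_ema_vars starts the shadow variable at 0.0

for step in range(1, 4):
    # update(): ema_t = ema * decay + param * (1 - decay)
    ema = ema * decay + param * (1 - decay)
    # apply() with zero_debias: divide out the bias from the zero initialization
    decay_pow = decay ** step
    ema_corrected = ema / (1.0 - decay_pow)
    print(step, ema, ema_corrected)  # ema_corrected stays at 1.0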
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import functools
import shutil
import sys
def main():
"""
Usage: when training with use_ema enabled and evaluating the EMA model, clean the saved model first.
To generate a cleaned model:
python ema_clean.py ema_model_dir cleaned_model_dir
"""
ema_model_dir = sys.argv[1]
cleaned_model_dir = sys.argv[2]
if not os.path.exists(cleaned_model_dir):
os.makedirs(cleaned_model_dir)
items = os.listdir(ema_model_dir)
for item in items:
if item.find('ema') > -1:
item_clean = item.replace('_ema_0', '')
shutil.copyfile(os.path.join(ema_model_dir, item),
os.path.join(cleaned_model_dir, item_clean))
elif item.find('mean') > -1 or item.find('variance') > -1:
shutil.copyfile(os.path.join(ema_model_dir, item),
os.path.join(cleaned_model_dir, item))
if __name__ == '__main__':
main()
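The cleaning step is only a suffix strip: EMA shadow parameters are saved with the '_ema_0' suffix generated by unique_name, while batch-norm mean/variance files are copied through unchanged. A tiny sketch with hypothetical file names:

# Hypothetical persistables saved after training with use_ema: True.
saved = ['conv1_weights_ema_0', 'fc_0.w_0_ema_0',
         'bn1_mean', 'bn1_variance', 'fc_0.b_0']

for item in saved:
    if 'ema' in item:
        print(item, '->', item.replace('_ema_0', ''))
    elif 'mean' in item or 'variance' in item:
        print(item, '-> copied as-is')
# conv1_weights_ema_0 -> conv1_weights
# fc_0.w_0_ema_0 -> fc_0.w_0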
......@@ -36,6 +36,8 @@ from ppcls.utils import logger
from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.collective import DistributedStrategy
from ema import ExponentialMovingAverage
def create_feeds(image_shape, use_mix=None):
"""
......@@ -86,7 +88,7 @@ def create_dataloader(feeds):
return dataloader
def create_model(architecture, image, classes_num):
def create_model(architecture, image, classes_num, is_train):
"""
Create a model
......@@ -101,6 +103,7 @@ def create_model(architecture, image, classes_num):
"""
name = architecture["name"]
params = architecture.get("params", {})
params['is_test'] = not is_train
model = architectures.__dict__[name](**params)
out = model.net(input=image, class_dim=classes_num)
return out
......@@ -336,7 +339,7 @@ def build(config, main_prog, startup_prog, is_train=True):
feeds = create_feeds(config.image_shape, use_mix=use_mix)
dataloader = create_dataloader(feeds.values())
out = create_model(config.ARCHITECTURE, feeds['image'],
config.classes_num)
config.classes_num, is_train)
fetchs = create_fetchs(
out,
feeds,
......@@ -354,6 +357,14 @@ def build(config, main_prog, startup_prog, is_train=True):
optimizer = mixed_precision_optimizer(config, optimizer)
optimizer = dist_optimizer(config, optimizer)
optimizer.minimize(fetchs['loss'][0])
if config.get('use_ema'):
global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter(
)
ema = ExponentialMovingAverage(
config.get('ema_decay'), thres_steps=global_steps)
ema.update()
return dataloader, fetchs, ema
return dataloader, fetchs
......@@ -387,7 +398,13 @@ def compile(config, program, loss_name=None):
total_step = 0
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
def run(dataloader,
exe,
program,
fetchs,
epoch=0,
mode='train',
vdl_writer=None):
"""
Feed data to the model and fetch the measures and loss
......@@ -401,6 +418,7 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None
Returns:
"""
print(fetchs)
fetch_list = [f[0] for f in fetchs.values()]
metric_list = [f[1] for f in fetchs.values()]
for m in metric_list:
......
......@@ -70,8 +70,12 @@ def main(args):
best_top1_acc = 0.0 # best top1 acc record
train_dataloader, train_fetchs = program.build(
config, train_prog, startup_prog, is_train=True)
if not config.get('use_ema'):
train_dataloader, train_fetchs = program.build(
config, train_prog, startup_prog, is_train=True)
else:
train_dataloader, train_fetchs, ema = program.build(
config, train_prog, startup_prog, is_train=True)
if config.validate:
valid_prog = fluid.Program()
......@@ -81,11 +85,11 @@ def main(args):
valid_prog = valid_prog.clone(for_test=True)
# create the "Executor" with the statement of which place
exe = fluid.Executor(place=place)
# only run startup_prog once to init
exe = fluid.Executor(place)
# Parameter initialization
exe.run(startup_prog)
# load model from checkpoint or pretrained model
# load model from 1. checkpoint to resume training, 2. pretrained model to finetune
init_model(config, train_prog, exe)
train_reader = Reader(config, 'train')()
......@@ -106,6 +110,14 @@ def main(args):
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
if config.get('use_ema'):
logger.info(logger.coloring("EMA validate start..."))
with ema.apply(exe):
top1_acc = program.run(valid_dataloader, exe,
compiled_valid_prog,
valid_fetchs, epoch_id, 'valid')
logger.info(logger.coloring("EMA validate over!"))
top1_acc = program.run(valid_dataloader, exe,
compiled_valid_prog, valid_fetchs,
epoch_id, 'valid')
......