未验证 提交 48848658 编写于 作者: D dyning 提交者: GitHub

Merge pull request #587 from littletomatodonkey/fix_cpp_repl

add warmup
...@@ -63,8 +63,9 @@ ...@@ -63,8 +63,9 @@
| beta1 | 设置一阶矩估计的指数衰减率 | 0.9 | \ | | beta1 | 设置一阶矩估计的指数衰减率 | 0.9 | \ |
| beta2 | 设置二阶矩估计的指数衰减率 | 0.999 | \ | | beta2 | 设置二阶矩估计的指数衰减率 | 0.999 | \ |
| decay | 是否使用decay | \ | \ | | decay | 是否使用decay | \ | \ |
| function(decay) | 设置decay方式 | - | 目前支持cosine_decay与piecewise_decay | | function(decay) | 设置decay方式 | - | 目前支持cosine_decay, cosine_decay_warmup与piecewise_decay |
| step_each_epoch | 每个epoch包含多少次迭代, cosine_decay时有效 | 20 | 计算方式:total_image_num / (batch_size_per_card * card_size) | | step_each_epoch | 每个epoch包含多少次迭代, cosine_decay/cosine_decay_warmup时有效 | 20 | 计算方式:total_image_num / (batch_size_per_card * card_size) |
| total_epoch | 总共迭代多少个epoch, cosine_decay时有效 | 1000 | 与Global.epoch_num 一致 | | total_epoch | 总共迭代多少个epoch, cosine_decay/cosine_decay_warmup时有效 | 1000 | 与Global.epoch_num 一致 |
| warmup_minibatch | 线性warmup的迭代次数, cosine_decay_warmup时有效 | 1000 | \ |
| boundaries | 学习率下降时的迭代次数间隔, piecewise_decay时有效 | - | 参数为列表形式 | | boundaries | 学习率下降时的迭代次数间隔, piecewise_decay时有效 | - | 参数为列表形式 |
| decay_rate | 学习率衰减系数, piecewise_decay时有效 | - | \ | | decay_rate | 学习率衰减系数, piecewise_decay时有效 | - | \ |
...@@ -60,8 +60,9 @@ Take `rec_icdar15_train.yml` as an example: ...@@ -60,8 +60,9 @@ Take `rec_icdar15_train.yml` as an example:
| beta1 | Set the exponential decay rate for the 1st moment estimates | 0.9 | \ | | beta1 | Set the exponential decay rate for the 1st moment estimates | 0.9 | \ |
| beta2 | Set the exponential decay rate for the 2nd moment estimates | 0.999 | \ | | beta2 | Set the exponential decay rate for the 2nd moment estimates | 0.999 | \ |
| decay | Whether to use decay | \ | \ | | decay | Whether to use decay | \ | \ |
| function(decay) | Set the decay function | cosine_decay | Support cosine_decay and piecewise_decay | | function(decay) | Set the decay function | cosine_decay | Support cosine_decay, cosine_decay_warmup and piecewise_decay |
| step_each_epoch | The number of steps in an epoch. Used in cosine_decay | 20 | Calculation :total_image_num / (batch_size_per_card * card_size) | | step_each_epoch | The number of steps in an epoch. Used in cosine_decay/cosine_decay_warmup | 20 | Calculation: total_image_num / (batch_size_per_card * card_size) |
| total_epoch | The number of epochs. Used in cosine_decay | 1000 | Consistent with Global.epoch_num | | total_epoch | The number of epochs. Used in cosine_decay/cosine_decay_warmup | 1000 | Consistent with Global.epoch_num |
| warmup_minibatch | Number of steps for linear warmup. Used in cosine_decay_warmup | 1000 | \ |
| boundaries | The step intervals to reduce learning rate. Used in piecewise_decay | - | The format is list | | boundaries | The step intervals to reduce learning rate. Used in piecewise_decay | - | The format is list |
| decay_rate | Learning rate decay rate. Used in piecewise_decay | - | \ | | decay_rate | Learning rate decay rate. Used in piecewise_decay | - | \ |
...@@ -14,14 +14,50 @@ ...@@ -14,14 +14,50 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import math
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.regularizer import L2Decay from paddle.fluid.regularizer import L2Decay
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
import paddle.fluid.layers.ops as ops
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
def cosine_decay_with_warmup(learning_rate,
step_each_epoch,
epochs=500,
warmup_minibatch=1000):
"""Applies cosine decay to the learning rate.
lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
decrease lr for every mini-batch and start with warmup.
"""
global_step = _decay_step_counter()
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
warmup_minibatch = fluid.layers.fill_constant(
shape=[1],
dtype='float32',
value=float(warmup_minibatch),
force_cpu=True)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(global_step < warmup_minibatch):
decayed_lr = learning_rate * (1.0 * global_step / warmup_minibatch)
fluid.layers.tensor.assign(input=decayed_lr, output=lr)
with switch.default():
decayed_lr = learning_rate * \
(ops.cos((global_step - warmup_minibatch) * (math.pi / (epochs * step_each_epoch))) + 1)/2
fluid.layers.tensor.assign(input=decayed_lr, output=lr)
return lr
def AdamDecay(params, parameter_list=None): def AdamDecay(params, parameter_list=None):
""" """
define optimizer function define optimizer function
...@@ -36,7 +72,9 @@ def AdamDecay(params, parameter_list=None): ...@@ -36,7 +72,9 @@ def AdamDecay(params, parameter_list=None):
l2_decay = params.get("l2_decay", 0.0) l2_decay = params.get("l2_decay", 0.0)
if 'decay' in params: if 'decay' in params:
supported_decay_mode = ["cosine_decay", "piecewise_decay"] supported_decay_mode = [
"cosine_decay", "cosine_decay_warmup", "piecewise_decay"
]
params = params['decay'] params = params['decay']
decay_mode = params['function'] decay_mode = params['function']
assert decay_mode in supported_decay_mode, "Supported decay mode is {}, but got {}".format( assert decay_mode in supported_decay_mode, "Supported decay mode is {}, but got {}".format(
...@@ -49,6 +87,15 @@ def AdamDecay(params, parameter_list=None): ...@@ -49,6 +87,15 @@ def AdamDecay(params, parameter_list=None):
learning_rate=base_lr, learning_rate=base_lr,
step_each_epoch=step_each_epoch, step_each_epoch=step_each_epoch,
epochs=total_epoch) epochs=total_epoch)
elif decay_mode == "cosine_decay_warmup":
step_each_epoch = params['step_each_epoch']
total_epoch = params['total_epoch']
warmup_minibatch = params.get("warmup_minibatch", 1000)
base_lr = cosine_decay_with_warmup(
learning_rate=base_lr,
step_each_epoch=step_each_epoch,
epochs=total_epoch,
warmup_minibatch=warmup_minibatch)
elif decay_mode == "piecewise_decay": elif decay_mode == "piecewise_decay":
boundaries = params["boundaries"] boundaries = params["boundaries"]
decay_rate = params["decay_rate"] decay_rate = params["decay_rate"]
...@@ -104,5 +151,5 @@ def RMSProp(params, parameter_list=None): ...@@ -104,5 +151,5 @@ def RMSProp(params, parameter_list=None):
optimizer = fluid.optimizer.RMSProp( optimizer = fluid.optimizer.RMSProp(
learning_rate=base_lr, learning_rate=base_lr,
regularization=fluid.regularizer.L2Decay(regularization_coeff=l2_decay)) regularization=fluid.regularizer.L2Decay(regularization_coeff=l2_decay))
return optimizer return optimizer
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册