Commit 86ce1af7 authored by littletomatodonkey, committed by ruri

fix warmup bug in cosine decay (#4217)

Parent 80cbdf27
@@ -35,7 +35,10 @@ def cosine_decay(learning_rate, step_each_epoch, epochs=120):
     return decayed_lr
-def cosine_decay_with_warmup(learning_rate, step_each_epoch, epochs=120):
+def cosine_decay_with_warmup(learning_rate,
+                             step_each_epoch,
+                             epochs=120,
+                             warm_up_epoch=5.0):
     """Applies cosine decay to the learning rate.
     lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
     decrease lr for every mini-batch and start with warmup.
@@ -49,7 +52,7 @@ def cosine_decay_with_warmup(learning_rate, step_each_epoch, epochs=120):
         name="learning_rate")
     warmup_epoch = fluid.layers.fill_constant(
-        shape=[1], dtype='float32', value=float(5), force_cpu=True)
+        shape=[1], dtype='float32', value=float(warm_up_epoch), force_cpu=True)
     epoch = ops.floor(global_step / step_each_epoch)
     with fluid.layers.control_flow.Switch() as switch:
@@ -59,20 +62,25 @@ def cosine_decay_with_warmup(learning_rate, step_each_epoch, epochs=120):
             fluid.layers.tensor.assign(input=decayed_lr, output=lr)
         with switch.default():
             decayed_lr = learning_rate * \
-                (ops.cos((global_step - warmup_epoch * step_each_epoch) * (math.pi / (epochs * step_each_epoch))) + 1)/2
+                (ops.cos((global_step - warmup_epoch * step_each_epoch) * (math.pi / ((epochs-warmup_epoch) * step_each_epoch))) + 1)/2
             fluid.layers.tensor.assign(input=decayed_lr, output=lr)
     return lr
-def exponential_decay_with_warmup(learning_rate, step_each_epoch, decay_epochs, decay_rate=0.97, warm_up_epoch=5.0):
+def exponential_decay_with_warmup(learning_rate,
+                                  step_each_epoch,
+                                  decay_epochs,
+                                  decay_rate=0.97,
+                                  warm_up_epoch=5.0):
     """Applies exponential decay to the learning rate.
     """
     global_step = _decay_step_counter()
     lr = fluid.layers.tensor.create_global_var(
-        shape=[1],
-        value=0.0,
-        dtype='float32',
-        persistable=True,
-        name="learning_rate")
+        shape=[1],
+        value=0.0,
+        dtype='float32',
+        persistable=True,
+        name="learning_rate")
     warmup_epoch = fluid.layers.fill_constant(
         shape=[1], dtype='float32', value=float(warm_up_epoch), force_cpu=True)
@@ -80,16 +88,19 @@ def exponential_decay_with_warmup(learning_rate, step_each_epoch, decay_epochs,
     epoch = ops.floor(global_step / step_each_epoch)
     with fluid.layers.control_flow.Switch() as switch:
         with switch.case(epoch < warmup_epoch):
-            decayed_lr = learning_rate * (global_step / (step_each_epoch * warmup_epoch))
+            decayed_lr = learning_rate * (global_step /
+                                          (step_each_epoch * warmup_epoch))
             fluid.layers.assign(input=decayed_lr, output=lr)
         with switch.default():
-            div_res = (global_step - warmup_epoch * step_each_epoch) / decay_epochs
+            div_res = (
+                global_step - warmup_epoch * step_each_epoch) / decay_epochs
             div_res = ops.floor(div_res)
-            decayed_lr = learning_rate * (decay_rate ** div_res)
+            decayed_lr = learning_rate * (decay_rate**div_res)
             fluid.layers.assign(input=decayed_lr, output=lr)
     return lr
 def lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
     """ Applies linear learning rate warmup for distributed training
     Argument learning_rate can be float or a Variable
@@ -193,7 +204,8 @@ class Optimizer(object):
         learning_rate = cosine_decay_with_warmup(
             learning_rate=self.lr,
             step_each_epoch=self.step,
-            epochs=self.num_epochs)
+            epochs=self.num_epochs,
+            warm_up_epoch=self.warm_up_epochs)
         optimizer = fluid.optimizer.Momentum(
             learning_rate=learning_rate,
             momentum=self.momentum_rate,
@@ -218,8 +230,7 @@
             regularization=fluid.regularizer.L2Decay(self.l2_decay),
             momentum=self.momentum_rate,
             rho=0.9,
-            epsilon=0.001
-        )
+            epsilon=0.001)
         return optimizer
     def linear_decay(self):
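The substance of the fix is the denominator of the cosine branch: after warmup, the decay now runs over (epochs - warm_up_epoch) * step_each_epoch steps instead of epochs * step_each_epoch, so the learning rate finishes the half-cosine at the last training step, and the warmup length becomes a configurable warm_up_epoch argument instead of a hard-coded 5. A minimal plain-Python sketch of the schedule after this change (the standalone helper below is illustrative only and not part of the patch):

import math


def lr_at_step(base_lr, global_step, step_each_epoch, epochs=120,
               warm_up_epoch=5.0):
    # Illustrative re-statement of cosine_decay_with_warmup after this fix,
    # using plain Python instead of fluid ops; the helper name is hypothetical.
    epoch = math.floor(global_step / step_each_epoch)
    if epoch < warm_up_epoch:
        # Linear warmup from 0 up to base_lr over the first warm_up_epoch epochs.
        return base_lr * global_step / (step_each_epoch * warm_up_epoch)
    # Cosine decay spread over the post-warmup steps only (the corrected denominator).
    progress = (global_step - warm_up_epoch * step_each_epoch) / \
        ((epochs - warm_up_epoch) * step_each_epoch)
    return base_lr * (math.cos(progress * math.pi) + 1) / 2


# With step_each_epoch=100: lr_at_step(0.1, 500, 100) gives 0.1 (end of warmup),
# and lr_at_step(0.1, 12000, 100) gives 0.0 (last step of 120 epochs).

With the defaults (epochs=120, warm_up_epoch=5), the old denominator stopped the cosine short of pi, leaving the final learning rate at about (cos(pi * 115 / 120) + 1) / 2, roughly 0.4% of the base rate, rather than completing the decay.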