Commit 99128a5c authored by minqiyang

Implement Cosine and Noam Decay

test=develop
Parent ec9c0874
@@ -14,10 +14,13 @@
 from __future__ import print_function
 
+import math
+
 from .. import unique_name
 
 __all__ = [
-    'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay', 'InverseTimeDecay'
+    'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
+    'InverseTimeDecay', 'CosineDecay'
 ]
@@ -34,7 +37,7 @@ class LearningRateDecay(object):
     def __call__(self):
         lr = self.step()
         if isinstance(lr, float):
-            lr = self._create_lr_var(lr)
+            lr = self.create_lr_var(lr)
         self.step_num += self.step_size
         return lr
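As context for the rename above: `step()` may return either a plain Python float or a framework variable, and `__call__` normalizes floats through `create_lr_var` before advancing `step_num` by `step_size`. A minimal sketch of that contract, assuming only the base-class API visible in this hunk (the `ConstantDecay` subclass is purely illustrative and not part of the commit):

```python
# Hypothetical example, not part of this commit: a scheduler whose step()
# returns a plain float and relies on LearningRateDecay.__call__ to wrap it
# into a learning-rate variable via create_lr_var.
class ConstantDecay(LearningRateDecay):
    def __init__(self, learning_rate, begin=0, step=1, dtype='float32'):
        super(ConstantDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate

    def step(self):
        # __call__ converts this float and then bumps self.step_num.
        return self.learning_rate
```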
@@ -166,18 +169,58 @@ class PolynomialDecay(LearningRateDecay):
     def step(self):
         from .. import layers
+        tmp_step_num = self.step_num
+        tmp_decay_steps = self.decay_steps
         if self.cycle:
             div_res = layers.ceil(
-                self.create_lr_var(self.step_num / self.decay_steps))
+                self.create_lr_var(tmp_step_num / self.decay_steps))
             zero_var = 0.0
             one_var = 1.0
 
-            if float(self.step_num) == zero_var:
+            if float(tmp_step_num) == zero_var:
                 div_res = one_var
-            decay_steps = self.decay_steps * div_res
+            tmp_decay_steps = self.decay_steps * div_res
         else:
-            global_step = global_step if global_step < self.decay_steps else self.decay_steps
+            tmp_step_num = self.create_lr_var(tmp_step_num
+                                              if tmp_step_num < self.decay_steps
+                                              else self.decay_steps)
 
-        decayed_lr = (self.learning_rate - self.end_learning_rate) * \
-            ((1 - global_step / self.decay_steps) ** self.power) + self.end_learning_rate
+        decayed_lr = (self.learning_rate - self.end_learning_rate) * \
+            ((1 - tmp_step_num / tmp_decay_steps) ** self.power) + self.end_learning_rate
 
-        return self.create_lr_var(decayed_lr)
+        return decayed_lr
+
+
+class CosineDecay(LearningRateDecay):
+    def __init__(self,
+                 learning_rate,
+                 step_each_epoch,
+                 epochs,
+                 begin=0,
+                 step=1,
+                 dtype='float32'):
+        super(CosineDecay, self).__init__(begin, step, dtype)
+        self.learning_rate = learning_rate
+        self.step_each_epoch = step_each_epoch
+        self.epochs = epochs
+
+    def step(self):
+        from .. import layers
+        cur_epoch = layers.floor(
+            self.create_lr_var(self.step_num / self.step_each_epoch))
+        decayed_lr = self.learning_rate * 0.5 * (
+            layers.cos(cur_epoch * math.pi / self.epochs) + 1)
+        return decayed_lr
+
+
+class NoamDecay(LearningRateDecay):
+    def __init__(self, d_model, warmup_steps, begin=1, step=1, dtype='float32'):
+        super(NoamDecay, self).__init__(begin, step, dtype)
+        self.d_model = d_model
+        self.warmup_steps = warmup_steps
+
+    def step(self):
+        from .. import layers
+        a = self.create_lr_var(self.step_num**-0.5)
+        b = self.create_lr_var((self.warmup_steps**-1.5) * self.step_num)
+        lr_value = (self.d_model**-0.5) * layers.elementwise_min(a, b)
+        return lr_value
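For a quick sanity check of the two new schedules, here is a plain-Python sketch (helper names are ours, no Paddle dependency) that evaluates the same formulas the dygraph classes compute: cosine decay follows `lr * 0.5 * (cos(epoch * pi / epochs) + 1)` per epoch, and Noam decay scales `d_model**-0.5` by the minimum of an inverse-square-root term and a linear warmup term.

```python
import math

# Plain-Python sketch of the formulas above; function names are illustrative
# and not part of the Paddle API.
def cosine_decay_value(learning_rate, step_num, step_each_epoch, epochs):
    cur_epoch = math.floor(step_num / step_each_epoch)
    return learning_rate * 0.5 * (math.cos(cur_epoch * math.pi / epochs) + 1)

def noam_decay_value(d_model, warmup_steps, step_num):
    a = step_num ** -0.5
    b = (warmup_steps ** -1.5) * step_num
    return (d_model ** -0.5) * min(a, b)

# The cosine schedule starts at the base LR and approaches 0 in the last
# epoch; the Noam schedule rises during warmup and then decays.
print(cosine_decay_value(0.1, step_num=0, step_each_epoch=10000, epochs=120))  # 0.1
print(noam_decay_value(512, warmup_steps=8000, step_num=1))
print(noam_decay_value(512, warmup_steps=8000, step_num=8000))  # warmup peak
```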
@@ -69,13 +69,17 @@ def noam_decay(d_model, warmup_steps):
         The decayed learning rate.
     """
     with default_main_program()._lr_schedule_guard():
-        global_step = _decay_step_counter(1)
+        if imperative_base.enabled():
+            decay = imperate_lr.NoamDecay(d_model, warmup_steps)
+            return decay
+        else:
+            global_step = _decay_step_counter(1)
 
-        a = global_step**-0.5
-        b = (warmup_steps**-1.5) * global_step
-        lr_value = (d_model**-0.5) * nn.elementwise_min(a, b)
+            a = global_step**-0.5
+            b = (warmup_steps**-1.5) * global_step
+            lr_value = (d_model**-0.5) * nn.elementwise_min(a, b)
 
-        return lr_value
+            return lr_value
 
 
 def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
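A small worked number for the branch above (our arithmetic, not from the source): the two terms inside `elementwise_min` cross at `global_step == warmup_steps`, where the schedule peaks before decaying.

```python
# Rough check of the warmup peak; d_model and warmup_steps match the values
# used in the new unit test below, the arithmetic itself is ours.
d_model, warmup_steps = 512, 8000
peak_lr = (d_model ** -0.5) * (warmup_steps ** -0.5)
print(peak_lr)  # ~4.94e-04, reached at global_step == warmup_steps
```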
@@ -364,12 +368,17 @@ def cosine_decay(learning_rate, step_each_epoch, epochs):
             learning_rate = base_lr, step_each_epoch=10000, epochs=120)
     """
     with default_main_program()._lr_schedule_guard():
-        global_step = _decay_step_counter()
+        if imperative_base.enabled():
+            decay = imperate_lr.CosineDecay(learning_rate, step_each_epoch,
+                                            epochs)
+            return decay
+        else:
+            global_step = _decay_step_counter()
 
-        cur_epoch = ops.floor(global_step / step_each_epoch)
-        decayed_lr = learning_rate * 0.5 * (
-            ops.cos(cur_epoch * math.pi / epochs) + 1)
-        return decayed_lr
+            cur_epoch = ops.floor(global_step / step_each_epoch)
+            decayed_lr = learning_rate * 0.5 * (
+                ops.cos(cur_epoch * math.pi / epochs) + 1)
+            return decayed_lr
 
 
 def append_LARS(params_grads, learning_rate, weight_decay):
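Usage stays the same across both modes: under the imperative guard, `cosine_decay` now returns an imperative `CosineDecay` object, while in static-graph mode it still returns a `Variable`. A hedged sketch, assuming the Fluid 1.x imperative API (`fluid.imperative.guard`) and mirroring the new unit tests in this commit:

```python
import paddle.fluid as fluid
from paddle.fluid.optimizer import SGDOptimizer

# Inside the imperative guard, cosine_decay returns an imperative CosineDecay
# object that the optimizer evaluates as training proceeds; outside the guard
# the same call builds the static-graph schedule as before.
with fluid.imperative.guard():
    optimizer = SGDOptimizer(learning_rate=fluid.layers.cosine_decay(
        learning_rate=0.1, step_each_epoch=10000, epochs=120))
```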
@@ -391,6 +400,9 @@ def append_LARS(params_grads, learning_rate, weight_decay):
                     / (sqrt(sumsq(gradient))+ weight_decay * sqrt(sumsq(param)))
     """
 
+    assert not imperative_base.enabled(
+    ), "append_LARS is NOT supported in dygraph mode now"
+
     def _balanced_weight(param_norm, grad_norm):
         if weight_decay == 1.0:
             return grad_norm + param_norm
...
@@ -195,6 +195,8 @@ class Optimizer(object):
             name = self._name + "_" + name
         if (name in self._accumulators and
                 param.name in self._accumulators[name]):
+            if framework._in_imperative_mode():
+                return self._accumulators[name][param.name]
             raise Exception("Accumulator {} already exists for parameter {}".
                             format(name, param.name))
         if shape == None:
...
@@ -43,7 +43,7 @@ class MLP(fluid.imperative.Layer):
 
 class TestImperativeOptimizerBase(unittest.TestCase):
     def setUp(self):
-        self.batch_num = 10
+        self.batch_num = 20
 
     def get_optimizer(self):
         raise NotImplementedError()
@@ -214,5 +214,25 @@ class TestImperativeOptimizerPolynomialDecay(TestImperativeOptimizerBase):
         self._check_mlp()
 
 
+class TestImperativeOptimizerCosineDecay(TestImperativeOptimizerBase):
+    def get_optimizer(self):
+        optimizer = SGDOptimizer(learning_rate=fluid.layers.cosine_decay(
+            learning_rate=0.1, step_each_epoch=10000, epochs=120))
+        return optimizer
+
+    def test_sgd(self):
+        self._check_mlp()
+
+
+class TestImperativeOptimizerNoamDecay(TestImperativeOptimizerBase):
+    def get_optimizer(self):
+        optimizer = SGDOptimizer(learning_rate=fluid.layers.noam_decay(
+            d_model=512, warmup_steps=8000))
+        return optimizer
+
+    def test_sgd(self):
+        self._check_mlp()
+
+
 if __name__ == '__main__':
     unittest.main()