Commit 175cf6e0 authored by Yu Yang

Add global_step in nn.py

Parent 95ea54fd
@@ -330,9 +330,28 @@ class LayerHelper(object):
         return self.main_program.current_block().create_var(*args, **kwargs)
 
     def create_global_variable(self, persistable=False, *args, **kwargs):
+        """
+        create global variable, note that there is no initializer for this global variable.
+        Args:
+            persistable(bool): True if it is a checkpoint value.
+            *args: See create_var's documentation
+            **kwargs: See create_var's documentation
+
+        Returns(Variable): the created variable.
+        """
         return self.main_program.global_block().create_var(
             *args, persistable=persistable, **kwargs)
 
+    def create_or_get_global_variable(self, name, *args, **kwargs):
+        """
+        Creates a global variable if not exists and returns the variable and
+        a boolean flag which is true when it is a new variable.
+        """
+        if self.main_program.global_block().has_var(name):
+            return self.main_program.global_block().var(name), False
+        else:
+            return self.create_global_variable(name=name, *args, **kwargs), True
+
     def set_variable_initializer(self, var, initializer):
         assert isinstance(var, Variable)
         self.startup_program.global_block().create_var(
...
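`create_or_get_global_variable` is the get-or-create primitive the new step counter below relies on: the first call for a name creates the variable, later calls just return it. A small hypothetical illustration (not part of the commit) of the returned flag; the `LayerHelper` import path is an assumption and may differ between Fluid versions:

```python
# Hypothetical illustration of create_or_get_global_variable's return flag.
# The import path is an assumption; adjust it to your Fluid version.
from paddle.fluid.layer_helper import LayerHelper

helper = LayerHelper('demo')
var1, is_new = helper.create_or_get_global_variable(
    name='@DEMO_COUNTER@', dtype='int64', shape=[1], persistable=True)
var2, is_new_again = helper.create_or_get_global_variable(
    name='@DEMO_COUNTER@', dtype='int64', shape=[1], persistable=True)

assert is_new is True         # first call creates the variable
assert is_new_again is False  # second call only fetches the existing one
assert var1.name == var2.name
```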
@@ -69,6 +69,7 @@ __all__ = [
     'softmax_with_cross_entropy',
     'smooth_l1',
     'one_hot',
+    'global_step_counter',
 ]
 
@@ -3250,3 +3251,25 @@ def one_hot(input, depth):
         attrs={'depth': depth},
         outputs={'Out': one_hot_out})
     return one_hot_out
+
+
+def global_step_counter():
+    """
+    Return the run counter of the main program, which is started with 1.
+    Returns(Variable): The global run counter.
+    """
+    helper = LayerHelper('global_step_counter')
+    counter_name = '@STEP_COUNTER@'
+    counter, is_new_var = helper.create_or_get_global_variable(
+        name=counter_name, dtype='int64', shape=[1], persistable=True)
+    if is_new_var:
+        helper.set_variable_initializer(
+            counter, initializer=Constant(
+                value=0, force_cpu=True))
+        helper.main_program.global_block().prepend_op(
+            type='increment',
+            inputs={'X': [counter]},
+            outputs={'Out': [counter]})
+        counter.stop_gradient = True
+
+    return counter
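The counter is created once under the reserved name `@STEP_COUNTER@`, and the increment op is prepended to the main program, so the counter advances before any other op on every run. A minimal sketch (not from this commit) of observing it, assuming the usual `import paddle.fluid as fluid` alias:

```python
# Minimal sketch: fetch the auto-incremented step counter across runs.
# Assumes `import paddle.fluid as fluid`; module paths varied between releases.
import paddle.fluid as fluid

counter = fluid.layers.global_step_counter()  # also prepends the increment op

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

for _ in range(3):
    step_val, = exe.run(fluid.default_main_program(), fetch_list=[counter])
    print(step_val)  # expected: [1], [2], [3] -- the counter starts at 1
```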
@@ -30,11 +30,14 @@ strategy according to this module.
 """
 
-def exponential_decay(learning_rate,
-                      global_step,
-                      decay_steps,
-                      decay_rate,
-                      staircase=False):
+def float_global_step():
+    # the first global step is zero in learning rate decay
+    global_step = layers.global_step_counter() - 1
+    global_step = layers.cast(global_step, 'float32')
+    return global_step
+
+
+def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     """Applies exponential decay to the learning rate.
 
     ```python
@@ -44,7 +47,6 @@ def exponential_decay(learning_rate,
     Args:
         learning_rate: A scalar float32 value or a Variable. This
           will be the initial learning rate during training
-        global_step: A Variable that record the training step.
         decay_steps: A Python `int32` number.
         decay_rate: A Python `float` number.
         staircase: Boolean. If set true, decay the learning rate every decay_steps.
@@ -52,8 +54,7 @@ def exponential_decay(learning_rate,
     Returns:
         The decayed learning rate
     """
-    if not isinstance(global_step, Variable):
-        raise ValueError("global_step is required for exponential_decay.")
+    global_step = float_global_step()
 
     with init_on_cpu():
         # update learning_rate
@@ -65,23 +66,17 @@ def exponential_decay(learning_rate,
     return decayed_lr
 
-def natural_exp_decay(learning_rate,
-                      global_step,
-                      decay_steps,
-                      decay_rate,
-                      staircase=False):
+def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     """Applies natural exponential decay to the initial learning rate.
 
-    ```python
-    if not staircase:
-        decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
-    else:
-        decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
-    ```
+    >>> if not staircase:
+    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
+    >>> else:
+    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
+
     Args:
         learning_rate: A scalar float32 value or a Variable. This
           will be the initial learning rate during training
-        global_step: A Variable that record the training step.
         decay_steps: A Python `int32` number.
         decay_rate: A Python `float` number.
         staircase: Boolean. If set true, decay the learning rate every decay_steps.
@@ -89,8 +84,7 @@ def natural_exp_decay(learning_rate,
     Returns:
         The decayed learning rate
     """
-    if not isinstance(global_step, Variable):
-        raise ValueError("global_step is required for natural_exp_decay.")
+    global_step = float_global_step()
 
     with init_on_cpu():
         div_res = global_step / decay_steps
@@ -101,23 +95,17 @@ def natural_exp_decay(learning_rate,
     return decayed_lr
 
-def inverse_time_decay(learning_rate,
-                       global_step,
-                       decay_steps,
-                       decay_rate,
-                       staircase=False):
+def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
     """Applies inverse time decay to the initial learning rate.
 
-    ```python
-    if staircase:
-        decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step))
-    else:
-        decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step)
-    ```
+    >>> if staircase:
+    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step))
+    >>> else:
+    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step)
+
     Args:
         learning_rate: A scalar float32 value or a Variable. This
-          will be the initial learning rate during training
-        global_step: A Variable that record the training step.
+          will be the initial learning rate during training.
         decay_steps: A Python `int32` number.
         decay_rate: A Python `float` number.
         staircase: Boolean. If set true, decay the learning rate every decay_steps.
@@ -125,8 +113,7 @@ def inverse_time_decay(learning_rate,
     Returns:
         The decayed learning rate
     """
-    if not isinstance(global_step, Variable):
-        raise ValueError("global_step is required for inverse_time_decay.")
+    global_step = float_global_step()
 
     with init_on_cpu():
         div_res = global_step / decay_steps
@@ -139,26 +126,22 @@ def inverse_time_decay(learning_rate,
 def polynomial_decay(learning_rate,
-                     global_step,
                      decay_steps,
                      end_learning_rate=0.0001,
                      power=1.0,
                      cycle=False):
     """Applies polynomial decay to the initial learning rate.
 
-    ```python
-    if cycle:
-        decay_steps = decay_steps * ceil(global_step / decay_steps)
-    else:
-        global_step = min(global_step, decay_steps)
-    decayed_learning_rate = (learning_rate - end_learning_rate) *
-            (1 - global_step / decay_steps) ^ power +
-            end_learning_rate
-    ```
+    >>> if cycle:
+    >>>     decay_steps = decay_steps * ceil(global_step / decay_steps)
+    >>> else:
+    >>>     global_step = min(global_step, decay_steps)
+    >>> decayed_learning_rate = (learning_rate - end_learning_rate) *
+    >>>         (1 - global_step / decay_steps) ^ power +
+    >>>         end_learning_rate
+
     Args:
         learning_rate: A scalar float32 value or a Variable. This
           will be the initial learning rate during training
-        global_step: A Variable that record the training step.
         decay_steps: A Python `int32` number.
         end_learning_rate: A Python `float` number.
         power: A Python `float` number
@@ -167,8 +150,7 @@ def polynomial_decay(learning_rate,
     Returns:
         The decayed learning rate
     """
-    if not isinstance(global_step, Variable):
-        raise ValueError("global_step is required for inverse_time_decay.")
+    global_step = float_global_step()
 
     with init_on_cpu():
         if cycle:
@@ -193,27 +175,24 @@ def polynomial_decay(learning_rate,
     return decayed_lr
 
-def piecewise_decay(global_step, boundaries, values):
+def piecewise_decay(boundaries, values):
     """Applies piecewise decay to the initial learning rate.
 
-    ```python
-    boundaries = [10000, 20000]
-    values = [1.0, 0.5, 0.1]
-
-    if step < 10000:
-        learning_rate = 1.0
-    elif step >= 10000 and step < 20000:
-        learning_rate = 0.5
-    else:
-        learning_rate = 0.1
-    ```
+    >>> boundaries = [10000, 20000]
+    >>> values = [1.0, 0.5, 0.1]
+    >>>
+    >>> if step < 10000:
+    >>>     learning_rate = 1.0
+    >>> elif 10000 <= step < 20000:
+    >>>     learning_rate = 0.5
+    >>> else:
+    >>>     learning_rate = 0.1
     """
     if len(values) - len(boundaries) != 1:
         raise ValueError("len(values) - len(boundaries) should be 1")
 
-    if not isinstance(global_step, Variable):
-        raise ValueError("global_step is required for piecewise_decay.")
+    global_step = float_global_step()
 
     with init_on_cpu():
         lr = layers.create_global_var(
...
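With this change the decay helpers derive the step internally from `layers.global_step_counter()` via `float_global_step()`, so call sites no longer pass a `global_step` variable at all. A hedged usage sketch (not part of the commit) of the new `piecewise_decay` signature, assuming the usual `import paddle.fluid as fluid` alias:

```python
# Sketch of the post-change decay API: no global_step argument anywhere.
# Assumes `import paddle.fluid as fluid`; adapt paths to your Fluid version.
import paddle.fluid as fluid

lr = fluid.learning_rate_decay.piecewise_decay(
    boundaries=[10000, 20000],   # step thresholds
    values=[1.0, 0.5, 0.1])      # learning rate on each segment

# The decayed learning rate Variable feeds straight into an optimizer.
optimizer = fluid.optimizer.SGD(learning_rate=lr)
```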
@@ -35,9 +35,8 @@ class Optimizer(object):
     but need to use one of it's implementation.
     """
 
-    def __init__(self, learning_rate, global_step=None, regularization=None):
+    def __init__(self, learning_rate, regularization=None):
         assert learning_rate is not None
-        self._global_step = global_step
         self.regularization = regularization
         self._global_learning_rate = learning_rate
         # Dictionary of accumulators. Some optimizer subclasses need to
@@ -144,26 +143,6 @@ class Optimizer(object):
                 format(name, param.name))
         return self._accumulators[name][param.name]
 
-    def _increment_global_step(self, block):
-        """Increment the global step by 1 after every iteration
-
-        Args:
-            block: the block in which the loss variable is present
-
-        Returns:
-            list with global_step increment op as its only element
-        """
-        assert isinstance(block, framework.Block)
-        assert self._global_step is not None
-        # create the increment op
-        increment_op = block.append_op(
-            type="increment",
-            inputs={"X": self._global_step},
-            outputs={"Out": self._global_step},
-            attrs={"step": 1.0})
-
-        return increment_op
-
     def create_optimization_pass(self,
                                  parameters_and_grads,
                                  loss,
@@ -210,8 +189,6 @@ class Optimizer(object):
         # FIXME: Need to fix this once we figure out how to handle dependencies
         self._finish_update(loss.block)
 
-        if self._global_step is not None:
-            self._increment_global_step(loss.block)
         end = len(global_block.ops)
         return global_block.slice_ops(start, end)
...
@@ -168,16 +168,12 @@ def train(use_cuda, save_dirname=None):
     # TODO(qiao)
     # check other optimizers and check why out will be NAN
-    global_step = fluid.layers.create_global_var(
-        shape=[1], value=0, dtype='float32', force_cpu=True, persistable=True)
     sgd_optimizer = fluid.optimizer.SGD(
         learning_rate=fluid.learning_rate_decay.exponential_decay(
             learning_rate=0.0001,
-            global_step=global_step,
             decay_steps=100000,
             decay_rate=0.5,
-            staircase=True),
-        global_step=global_step)
+            staircase=True))
     sgd_optimizer.minimize(avg_cost)
 
     # TODO(qiao)
...
@@ -28,7 +28,7 @@ def exponential_decay(learning_rate,
                       decay_steps,
                       decay_rate,
                       staircase=False):
-    exponent = float(global_step) / float(decay_steps)
+    exponent = global_step / decay_steps
     if staircase:
         exponent = math.floor(exponent)
     return learning_rate * decay_rate**exponent
@@ -83,22 +83,25 @@ def piecewise_decay(global_step, boundaries, values):
 
 class TestLearningRateDecay(unittest.TestCase):
     def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
-        global_step = layers.create_global_var(
-            shape=[1], value=0.0, dtype='float32', persistable=True)
-
-        decayed_lr = fluid_decay_fn(global_step=global_step, **kwargs)
-        layers.increment(global_step, 1.0)
+        decayed_lr = fluid_decay_fn(**kwargs)
 
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
 
         exe.run(fluid.default_startup_program())
         for step in range(10):
-            step_val, lr_val = exe.run(fluid.default_main_program(),
-                                       feed=[],
-                                       fetch_list=[global_step, decayed_lr])
-            python_decayed_lr = python_decay_fn(global_step=step, **kwargs)
-            self.assertAlmostEqual(python_decayed_lr, lr_val[0])
+            step_val, lr_val = exe.run(
+                fluid.default_main_program(),
+                feed=[],
+                fetch_list=[fluid.layers.global_step_counter(), decayed_lr])
+            python_decayed_lr = python_decay_fn(
+                global_step=float(step), **kwargs)
+            self.assertAlmostEqual(
+                python_decayed_lr,
+                lr_val[0],
+                msg='Failed fn is {0}, Python result is {1}, Fluid result is {2}'.
+                format(python_decay_fn.__name__,
+                       str(python_decayed_lr), str(lr_val[0])))
 
     def test_decay(self):
         common_kwargs_true = {
...
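The test's alignment between the Python reference (`global_step=float(step)`) and the Fluid result follows from the prepended increment: the fetched counter equals `step + 1`, while the decay graph sees `counter - 1 == step`. A hedged end-to-end check (not from the commit) making that explicit, assuming `import paddle.fluid as fluid` and numpy are available:

```python
# Hedged sketch: verify the step alignment between the counter and the decay.
# Assumes `import paddle.fluid as fluid`; adapt to your Fluid version.
import numpy as np
import paddle.fluid as fluid

decayed_lr = fluid.learning_rate_decay.exponential_decay(
    learning_rate=1.0, decay_steps=1, decay_rate=0.5, staircase=True)
counter = fluid.layers.global_step_counter()  # reuses the existing counter

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

for step in range(3):
    counter_val, lr_val = exe.run(fluid.default_main_program(),
                                  fetch_list=[counter, decayed_lr])
    assert int(counter_val[0]) == step + 1          # counter starts at 1
    np.testing.assert_allclose(lr_val[0], 0.5**step)  # decay uses counter - 1
```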