Unverified Commit d3fbede9 authored by Y Yu Yang, committed by GitHub

Merge pull request #8564 from reyoung/feature/add_global_step

Add global_step in nn.py
......@@ -330,9 +330,28 @@ class LayerHelper(object):
return self.main_program.current_block().create_var(*args, **kwargs)
def create_global_variable(self, persistable=False, *args, **kwargs):
"""
Create a global variable. Note that this global variable has no initializer attached.
Args:
persistable(bool): True if it is a checkpoint value.
*args: See create_var's documentation
**kwargs: See create_var's documentation
Returns(Variable): the created variable.
"""
return self.main_program.global_block().create_var(
*args, persistable=persistable, **kwargs)
def create_or_get_global_variable(self, name, *args, **kwargs):
"""
Creates a global variable if it does not already exist, and returns the variable
together with a boolean flag that is True when the variable was newly created.
"""
if self.main_program.global_block().has_var(name):
return self.main_program.global_block().var(name), False
else:
return self.create_global_variable(name=name, *args, **kwargs), True
def set_variable_initializer(self, var, initializer):
assert isinstance(var, Variable)
self.startup_program.global_block().create_var(
......
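The two new helpers are meant to be used together in a create-once pattern, which `autoincreased_step_counter` below follows. A minimal sketch, assuming a `LayerHelper` instance named `helper` and the fluid `Constant` initializer are in scope; the variable name is illustrative only:

```python
# Sketch: create the persistable variable once, attach its initializer only
# when the returned flag says it is new, and simply reuse it afterwards.
counter, is_new_var = helper.create_or_get_global_variable(
    name='@MY_COUNTER@', dtype='int64', shape=[1], persistable=True)
if is_new_var:
    helper.set_variable_initializer(counter, initializer=Constant(value=0))
```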
......@@ -70,6 +70,7 @@ __all__ = [
'softmax_with_cross_entropy',
'smooth_l1',
'one_hot',
'autoincreased_step_counter',
]
......@@ -3236,3 +3237,34 @@ def one_hot(input, depth):
attrs={'depth': depth},
outputs={'Out': one_hot_out})
return one_hot_out
def autoincreased_step_counter(counter_name=None, begin=1, step=1):
"""
NOTE: The counter is automatically increased by `step` (default 1) every mini-batch.
Return the run counter of the main program, which starts from `begin` (1 by default).
Args:
counter_name(str): The counter name, default is '@STEP_COUNTER@'.
begin(int): The first value of this counter.
step(int): The value added to the counter on each execution.
Returns(Variable): The global run counter.
"""
helper = LayerHelper('global_step_counter')
if counter_name is None:
counter_name = '@STEP_COUNTER@'
counter, is_new_var = helper.create_or_get_global_variable(
name=counter_name, dtype='int64', shape=[1], persistable=True)
if is_new_var:
helper.set_variable_initializer(
counter, initializer=Constant(
value=begin - 1, force_cpu=True))
helper.main_program.global_block().prepend_op(
type='increment',
inputs={'X': [counter]},
outputs={'Out': [counter]},
attrs={'step': float(step)})
counter.stop_gradient = True
return counter
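A usage sketch for the new layer, assuming the standard `fluid` executor setup; the counter is a persistable global variable, so it can simply be fetched each iteration:

```python
# Sketch: the counter lives in the global block and is bumped by the
# prepended `increment` op every time the main program runs.
import paddle.fluid as fluid

step = fluid.layers.autoincreased_step_counter()  # uses '@STEP_COUNTER@', begin=1, step=1

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for _ in range(3):
    step_val, = exe.run(fluid.default_main_program(), fetch_list=[step])
    # step_val holds [1], [2], [3] on successive runs
```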
......@@ -13,7 +13,6 @@
# limitations under the License.
import layers
from framework import Variable
from initializer import init_on_cpu
__all__ = [
......@@ -30,11 +29,15 @@ strategy according to this module.
"""
def exponential_decay(learning_rate,
global_step,
decay_steps,
decay_rate,
staircase=False):
def _decay_step_counter():
# the first global step is zero in learning rate decay
global_step = layers.autoincreased_step_counter(
counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)
global_step = layers.cast(global_step, 'float32')
return global_step
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
"""Applies exponential decay to the learning rate.
```python
......@@ -44,7 +47,6 @@ def exponential_decay(learning_rate,
Args:
learning_rate: A scalar float32 value or a Variable. This
will be the initial learning rate during training
global_step: A Variable that records the training step.
decay_steps: A Python `int32` number.
decay_rate: A Python `float` number.
staircase: Boolean. If set true, decay the learning rate every decay_steps.
......@@ -52,8 +54,7 @@ def exponential_decay(learning_rate,
Returns:
The decayed learning rate
"""
if not isinstance(global_step, Variable):
raise ValueError("global_step is required for exponential_decay.")
global_step = _decay_step_counter()
with init_on_cpu():
# update learning_rate
......@@ -65,23 +66,17 @@ def exponential_decay(learning_rate,
return decayed_lr
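With `global_step` removed from the signature, call sites only pass the schedule hyper-parameters. A minimal sketch of the new usage, mirroring the updated training example further down in this change (the hyper-parameter values are illustrative):

```python
# Sketch: exponential LR decay without an explicit global_step; the helper
# creates and increments '@LR_DECAY_COUNTER@' on its own.
import paddle.fluid as fluid

lr = fluid.learning_rate_decay.exponential_decay(
    learning_rate=0.0001, decay_steps=100000, decay_rate=0.5, staircase=True)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=lr)
```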
def natural_exp_decay(learning_rate,
global_step,
decay_steps,
decay_rate,
staircase=False):
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
"""Applies natural exponential decay to the initial learning rate.
```python
if not staircase:
decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
else:
decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
```
>>> if not staircase:
>>> decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
>>> else:
>>> decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))
Args:
learning_rate: A scalar float32 value or a Variable. This
will be the initial learning rate during training
global_step: A Variable that records the training step.
decay_steps: A Python `int32` number.
decay_rate: A Python `float` number.
staircase: Boolean. If set true, decay the learning rate every decay_steps.
......@@ -89,8 +84,7 @@ def natural_exp_decay(learning_rate,
Returns:
The decayed learning rate
"""
if not isinstance(global_step, Variable):
raise ValueError("global_step is required for natural_exp_decay.")
global_step = _decay_step_counter()
with init_on_cpu():
div_res = global_step / decay_steps
......@@ -101,23 +95,17 @@ def natural_exp_decay(learning_rate,
return decayed_lr
def inverse_time_decay(learning_rate,
global_step,
decay_steps,
decay_rate,
staircase=False):
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
"""Applies inverse time decay to the initial learning rate.
```python
if staircase:
decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step))
else:
decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step)
```
>>> if staircase:
>>> decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step))
>>> else:
>>> decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step)
Args:
learning_rate: A scalar float32 value or a Variable. This
will be the initial learning rate during training
global_step: A Variable that records the training step.
will be the initial learning rate during training.
decay_steps: A Python `int32` number.
decay_rate: A Python `float` number.
staircase: Boolean. If set true, decay the learning rate every decay_steps.
......@@ -125,8 +113,7 @@ def inverse_time_decay(learning_rate,
Returns:
The decayed learning rate
"""
if not isinstance(global_step, Variable):
raise ValueError("global_step is required for inverse_time_decay.")
global_step = _decay_step_counter()
with init_on_cpu():
div_res = global_step / decay_steps
......@@ -139,26 +126,22 @@ def inverse_time_decay(learning_rate,
def polynomial_decay(learning_rate,
global_step,
decay_steps,
end_learning_rate=0.0001,
power=1.0,
cycle=False):
"""Applies polynomial decay to the initial learning rate.
```python
if cycle:
decay_steps = decay_steps * ceil(global_step / decay_steps)
else:
global_step = min(global_step, decay_steps)
decayed_learning_rate = (learning_rate - end_learning_rate) *
(1 - global_step / decay_steps) ^ power +
end_learning_rate
```
>>> if cycle:
>>> decay_steps = decay_steps * ceil(global_step / decay_steps)
>>> else:
>>> global_step = min(global_step, decay_steps)
>>> decayed_learning_rate = (learning_rate - end_learning_rate) *
>>> (1 - global_step / decay_steps) ^ power +
>>> end_learning_rate
Args:
learning_rate: A scalar float32 value or a Variable. This
will be the initial learning rate during training
global_step: A Variable that records the training step.
decay_steps: A Python `int32` number.
end_learning_rate: A Python `float` number.
power: A Python `float` number
......@@ -167,8 +150,7 @@ def polynomial_decay(learning_rate,
Returns:
The decayed learning rate
"""
if not isinstance(global_step, Variable):
raise ValueError("global_step is required for inverse_time_decay.")
global_step = _decay_step_counter()
with init_on_cpu():
if cycle:
......@@ -193,27 +175,24 @@ def polynomial_decay(learning_rate,
return decayed_lr
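For reference, a hedged sketch of the updated `polynomial_decay` signature; the values below are illustrative only:

```python
# Sketch: polynomial decay from 0.01 down to 1e-4 over 10000 steps, restarting
# the schedule each cycle; the step counter is created internally now.
lr = fluid.learning_rate_decay.polynomial_decay(
    learning_rate=0.01,
    decay_steps=10000,
    end_learning_rate=1e-4,
    power=1.0,
    cycle=True)
```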
def piecewise_decay(global_step, boundaries, values):
def piecewise_decay(boundaries, values):
"""Applies piecewise decay to the initial learning rate.
```python
boundaries = [10000, 20000]
values = [1.0, 0.5, 0.1]
if step < 10000:
learning_rate = 1.0
elif step >= 10000 and step < 20000:
learning_rate = 0.5
else:
learning_rate = 0.1
```
>>> boundaries = [10000, 20000]
>>> values = [1.0, 0.5, 0.1]
>>>
>>> if step < 10000:
>>> learning_rate = 1.0
>>> elif 10000 <= step < 20000:
>>> learning_rate = 0.5
>>> else:
>>> learning_rate = 0.1
"""
if len(values) - len(boundaries) != 1:
raise ValueError("len(values) - len(boundaries) should be 1")
if not isinstance(global_step, Variable):
raise ValueError("global_step is required for piecewise_decay.")
global_step = _decay_step_counter()
with init_on_cpu():
lr = layers.create_global_var(
......
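A short sketch of the boundaries/values form described in the docstring above (values are illustrative):

```python
# Sketch: three LR plateaus separated by two step boundaries, matching the
# pseudocode in the piecewise_decay docstring.
lr = fluid.learning_rate_decay.piecewise_decay(
    boundaries=[10000, 20000], values=[1.0, 0.5, 0.1])
```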
......@@ -35,11 +35,10 @@ class Optimizer(object):
but need to use one of its implementations.
"""
def __init__(self, learning_rate, global_step=None, regularization=None):
def __init__(self, learning_rate, regularization=None):
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable):
raise TypeError("learning rate should be float or Variable")
self._global_step = global_step
self.regularization = regularization
self._learning_rate = learning_rate
# each program should have an independent learning rate
......@@ -159,26 +158,6 @@ class Optimizer(object):
format(name, param.name))
return self._accumulators[name][param.name]
def _increment_global_step(self, block):
"""Increment the global step by 1 after every iteration
Args:
block: the block in which the loss variable is present
Returns:
list with global_step increment op as its only element
"""
assert isinstance(block, framework.Block)
assert self._global_step is not None
# create the increment op
increment_op = block.append_op(
type="increment",
inputs={"X": self._global_step},
outputs={"Out": self._global_step},
attrs={"step": 1.0})
return increment_op
def create_optimization_pass(self,
parameters_and_grads,
loss,
......@@ -225,8 +204,6 @@ class Optimizer(object):
# FIXME: Need to fix this once we figure out how to handle dependencies
self._finish_update(loss.block)
if self._global_step is not None:
self._increment_global_step(loss.block)
end = len(global_block.ops)
return global_block.slice_ops(start, end)
......
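Since the optimizer no longer accepts or increments a global step, callers that still need one can create it explicitly with the new layer. A hedged sketch of the replacement pattern; `avg_cost` is assumed to be the loss variable of the surrounding program:

```python
# Sketch: the step counter is now an ordinary layer output rather than
# optimizer-managed state.
import paddle.fluid as fluid

global_step = fluid.layers.autoincreased_step_counter()
sgd = fluid.optimizer.SGD(learning_rate=0.01)
optimize_ops, params_grads = sgd.minimize(avg_cost)  # avg_cost: assumed loss variable
# global_step can be added to fetch_list to observe the iteration count.
```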
......@@ -169,16 +169,12 @@ def train(use_cuda, save_dirname=None, is_local=True):
# TODO(qiao)
# check other optimizers and check why out will be NAN
global_step = fluid.layers.create_global_var(
shape=[1], value=0, dtype='float32', force_cpu=True, persistable=True)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.learning_rate_decay.exponential_decay(
learning_rate=0.0001,
global_step=global_step,
decay_steps=100000,
decay_rate=0.5,
staircase=True),
global_step=global_step)
staircase=True))
optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
# TODO(qiao)
......
......@@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import math
import copy
import math
import unittest
import paddle.fluid.framework as framework
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.framework as framework
import paddle.fluid.learning_rate_decay as lr_decay
......@@ -28,7 +26,7 @@ def exponential_decay(learning_rate,
decay_steps,
decay_rate,
staircase=False):
exponent = float(global_step) / float(decay_steps)
exponent = global_step / decay_steps
if staircase:
exponent = math.floor(exponent)
return learning_rate * decay_rate**exponent
......@@ -83,22 +81,24 @@ def piecewise_decay(global_step, boundaries, values):
class TestLearningRateDecay(unittest.TestCase):
def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
global_step = layers.create_global_var(
shape=[1], value=0.0, dtype='float32', persistable=True)
decayed_lr = fluid_decay_fn(global_step=global_step, **kwargs)
layers.increment(global_step, 1.0)
decayed_lr = fluid_decay_fn(**kwargs)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for step in range(10):
step_val, lr_val = exe.run(fluid.default_main_program(),
feed=[],
fetch_list=[global_step, decayed_lr])
python_decayed_lr = python_decay_fn(global_step=step, **kwargs)
self.assertAlmostEqual(python_decayed_lr, lr_val[0])
lr_val, = exe.run(fluid.default_main_program(),
feed=[],
fetch_list=[decayed_lr])
python_decayed_lr = python_decay_fn(
global_step=float(step), **kwargs)
self.assertAlmostEqual(
python_decayed_lr,
lr_val[0],
msg='Failed fn is {0}, Python result is {1}, Fluid result is {2}'.
format(python_decay_fn.__name__,
str(python_decayed_lr), str(lr_val[0])))
def test_decay(self):
common_kwargs_true = {
......
......@@ -46,43 +46,6 @@ class TestOptimizer(unittest.TestCase):
self.assertEqual([op.type for op in opts],
["fill_constant", "elementwise_mul", "sgd"])
def test_sgd_optimizer_with_global_step(self):
init_program = framework.Program()
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
global_step = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="step")
learning_rate = 0.01
sgd_optimizer = optimizer.SGDOptimizer(
learning_rate=learning_rate, global_step=global_step)
opts, _ = sgd_optimizer.minimize(mean_out, init_program)
self.assertEqual(len(opts), 4)
self.assertEqual(
[op.type for op in opts],
["fill_constant", "elementwise_mul", "sgd", "increment"])
# Check init_program
init_ops = init_program.global_block().ops
self.assertEqual(len(init_ops), 1)
self.assertEqual(init_ops[0].type, "fill_constant")
self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)
class TestMomentumOptimizer(unittest.TestCase):
class MockMomentum(optimizer.MomentumOptimizer):
......