Commit 9db107da authored by Yu Yang

Renamed and added comments

Parent 2af9aac2
@@ -69,7 +69,7 @@ __all__ = [
     'softmax_with_cross_entropy',
     'smooth_l1',
     'one_hot',
-    'global_step_counter',
+    'autoincreased_step_counter',
 ]
@@ -3253,23 +3253,32 @@ def one_hot(input, depth):
     return one_hot_out


-def global_step_counter():
+def autoincreased_step_counter(counter_name=None, begin=1, step=1):
     """
+    NOTE: The counter will be automatically increased by 1 every mini-batch
     Return the run counter of the main program, which is started with 1.
+
+    Args:
+        counter_name(str): The counter name, default is '@STEP_COUNTER@'.
+        begin(int): The first value of this counter.
+        step(int): The increment step between each execution.

     Returns(Variable): The global run counter.
     """
     helper = LayerHelper('global_step_counter')
-    counter_name = '@STEP_COUNTER@'
+    if counter_name is None:
+        counter_name = '@STEP_COUNTER@'
     counter, is_new_var = helper.create_or_get_global_variable(
         name=counter_name, dtype='int64', shape=[1], persistable=True)
     if is_new_var:
         helper.set_variable_initializer(
             counter, initializer=Constant(
-                value=0, force_cpu=True))
+                value=begin - 1, force_cpu=True))
         helper.main_program.global_block().prepend_op(
             type='increment',
             inputs={'X': [counter]},
-            outputs={'Out': [counter]})
+            outputs={'Out': [counter]},
+            attrs={'step': float(step)})
         counter.stop_gradient = True

     return counter
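
For reference, a minimal usage sketch of the renamed layer. The counter name '@DEMO_COUNTER@' is an illustrative assumption, and the plain `paddle.fluid` import path may differ across releases (this era shipped it as `paddle.v2.fluid`):

import paddle.fluid as fluid

# The counter variable is initialized to begin - 1, and a prepended
# increment op adds `step` on every run of the main program, so the
# first fetched value equals `begin` when step == 1.
counter = fluid.layers.autoincreased_step_counter(
    counter_name='@DEMO_COUNTER@', begin=1, step=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for _ in range(3):
    step_val, = exe.run(fluid.default_main_program(), fetch_list=[counter])
    print(step_val)  # [1], then [2], then [3]
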
@@ -32,7 +32,8 @@ strategy according to this module.
 def float_global_step():
     # the first global step is zero in learning rate decay
-    global_step = layers.global_step_counter() - 1
+    global_step = layers.autoincreased_step_counter(
+        counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)
     global_step = layers.cast(global_step, 'float32')
     return global_step
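
The decay functions in this module consume the float step above. As a Python-side reference for what the tests below compare against, here is a sketch of the conventional exponential decay formula, lr = learning_rate * decay_rate ** (global_step / decay_steps); the function name and staircase flag are illustrative, not the module's exact `python_decay_fn`:

import math

def exponential_decay_py(global_step, learning_rate, decay_steps,
                         decay_rate, staircase=False):
    # staircase=True floors the exponent, so the learning rate decays
    # in discrete jumps instead of continuously.
    exponent = global_step / float(decay_steps)
    if staircase:
        exponent = math.floor(exponent)
    return learning_rate * decay_rate ** exponent

print(exponential_decay_py(0.0, 0.1, 100, 0.5))    # 0.1 at step 0
print(exponential_decay_py(100.0, 0.1, 100, 0.5))  # 0.05 one period later

Starting the counter at begin=0 keeps the very first learning rate undecayed, which is why float_global_step no longer needs the old `- 1` correction.
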
@@ -93,7 +93,9 @@ class TestLearningRateDecay(unittest.TestCase):
         step_val, lr_val = exe.run(
             fluid.default_main_program(),
             feed=[],
-            fetch_list=[fluid.layers.global_step_counter(), decayed_lr])
+            fetch_list=[
+                fluid.layers.autoincreased_step_counter(), decayed_lr
+            ])
         python_decayed_lr = python_decay_fn(
             global_step=float(step), **kwargs)
         self.assertAlmostEqual(