提交 9db107da 编写于 作者: Y Yu Yang

Renamed and add comments

上级 2af9aac2
......@@ -69,7 +69,7 @@ __all__ = [
'softmax_with_cross_entropy',
'smooth_l1',
'one_hot',
'global_step_counter',
'autoincreased_step_counter',
]
......@@ -3253,23 +3253,32 @@ def one_hot(input, depth):
return one_hot_out
def global_step_counter():
def autoincreased_step_counter(counter_name=None, begin=1, step=1):
"""
NOTE: The counter will be automatically increased by 1 every mini-batch
Return the run counter of the main program, which is started with 1.
Args:
counter_name(str): The counter name, default is '@STEP_COUNTER@'.
begin(int): The first value of this counter.
step(int): The increment step between each execution.
Returns(Variable): The global run counter.
"""
helper = LayerHelper('global_step_counter')
counter_name = '@STEP_COUNTER@'
if counter_name is None:
counter_name = '@STEP_COUNTER@'
counter, is_new_var = helper.create_or_get_global_variable(
name=counter_name, dtype='int64', shape=[1], persistable=True)
if is_new_var:
helper.set_variable_initializer(
counter, initializer=Constant(
value=0, force_cpu=True))
value=begin - 1, force_cpu=True))
helper.main_program.global_block().prepend_op(
type='increment',
inputs={'X': [counter]},
outputs={'Out': [counter]})
outputs={'Out': [counter]},
attrs={'step': float(step)})
counter.stop_gradient = True
return counter
......@@ -32,7 +32,8 @@ strategy according to this module.
def float_global_step():
# the first global step is zero in learning rate decay
global_step = layers.global_step_counter() - 1
global_step = layers.autoincreased_step_counter(
counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)
global_step = layers.cast(global_step, 'float32')
return global_step
......
......@@ -93,7 +93,9 @@ class TestLearningRateDecay(unittest.TestCase):
step_val, lr_val = exe.run(
fluid.default_main_program(),
feed=[],
fetch_list=[fluid.layers.global_step_counter(), decayed_lr])
fetch_list=[
fluid.layers.autoincreased_step_counter(), decayed_lr
])
python_decayed_lr = python_decay_fn(
global_step=float(step), **kwargs)
self.assertAlmostEqual(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册