Unverified commit ea60e644, authored by mapingshuo, committed by GitHub

correct the LookaheadOptimizer programDesc, test=develop (#25688)

Parent: b5f8784c
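The fix below wraps the creation of the lookahead bookkeeping variables (k, alpha, step, and the constants used by the Switch) in framework.program_guard, so they are recorded into the caller's main and startup programs rather than whatever default program happens to be active. A minimal sketch of that pattern using the Paddle 1.x fluid API (the program and variable names here are illustrative, not taken from the patch):

    import paddle.fluid as fluid

    main_prog = fluid.Program()
    startup_prog = fluid.Program()

    # Variables and ops created inside program_guard are added to
    # main_prog / startup_prog instead of the global default programs.
    with fluid.program_guard(main_prog, startup_prog):
        k = fluid.layers.create_global_var(
            name="lookahead_k", shape=[1], value=5, dtype='int32',
            persistable=True)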
@@ -4884,6 +4884,7 @@ class LookaheadOptimizer(object):
             inputs={"X": fast_var},
             outputs={"Out": slow_var})

+        with framework.program_guard(main_block.program, startup_program):
         # Add Var k to main prog and startup prog
         k = layers.create_global_var(
             name="lookahead_k",
@@ -4910,9 +4911,11 @@ class LookaheadOptimizer(object):
         layers.increment(x=step, value=1.0, in_place=True)

         # lookahead
-        zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
+        zero_var = layers.fill_constant(
+            shape=[1], dtype='float32', value=0.0)

-        one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
+        one_var = layers.fill_constant(
+            shape=[1], dtype='float32', value=1.0)

         mod = layers.elementwise_mod(step, k)
         with layers.control_flow.Switch() as switch:
@@ -4923,7 +4926,8 @@ class LookaheadOptimizer(object):
                     tmp_var = layers.elementwise_add(
                         layers.elementwise_mul(fast_var, alpha),
                         layers.elementwise_mul(
-                            slow_var, layers.elementwise_sub(one_var, alpha)))
+                            slow_var,
+                            layers.elementwise_sub(one_var, alpha)))
                     layers.assign(input=tmp_var, output=slow_var)
                     layers.assign(input=tmp_var, output=fast_var)
                 with switch.default():
...
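For reference, the Switch branch above implements the Lookahead update rule: every k fast-optimizer steps, the slow weights are blended as slow = alpha * fast + (1 - alpha) * slow, and the result is assigned back to both the slow and fast weights. A minimal NumPy sketch of that rule, with illustrative names and default values not taken from the Paddle source:

    import numpy as np

    def lookahead_sync(fast, slow, step, k=5, alpha=0.5):
        # Mirrors the mod == zero_var branch: on every k-th step,
        # blend fast into slow (tmp_var in the diff) and copy the
        # result back to both fast_var and slow_var.
        if step % k == 0:
            slow = alpha * fast + (1.0 - alpha) * slow
            fast = slow.copy()
        return fast, slow

    # Toy usage with a stand-in fast-optimizer step.
    fast = np.ones(3)
    slow = fast.copy()
    for step in range(1, 11):
        fast = fast - 0.1 * fast  # placeholder for the inner update
        fast, slow = lookahead_sync(fast, slow, step)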