未验证 提交 7ae10900 编写于 作者: M mapingshuo 提交者: GitHub

fix slow var initialization, test=develop (#26516)

上级 5407e327
......@@ -1141,7 +1141,7 @@ class MomentumOptimizer(Optimizer):
class DGCMomentumOptimizer(Optimizer):
"""
:api_attr: Static Graph
:api_attr: Static Graph
DGC (Deep Gradient Compression) Momentum Optimizer. Original paper is https://arxiv.org/abs/1712.01887
......@@ -3067,7 +3067,7 @@ Lamb = LambOptimizer
class ModelAverage(Optimizer):
"""
:api_attr: Static Graph
:api_attr: Static Graph
The ModelAverage optimizer accumulates specific continuous historical parameters
during training. The accumulated historical range can be controlled by the passed
......@@ -3376,7 +3376,7 @@ class ModelAverage(Optimizer):
class ExponentialMovingAverage(object):
"""
:api_attr: Static Graph
:api_attr: Static Graph
Compute the moving average of parameters with exponential decay.
Given a parameter :math:`\\theta`, its exponential moving average (EMA)
......@@ -3626,7 +3626,7 @@ class ExponentialMovingAverage(object):
class PipelineOptimizer(object):
"""
:api_attr: Static Graph
:api_attr: Static Graph
Pipeline Optimizer: Make a program to run as pipeline, that is splitting a
program into multiple sections (sub-programs) and each section run on a
......@@ -4477,7 +4477,7 @@ class PipelineOptimizer(object):
class RecomputeOptimizer(Optimizer):
"""
:api_attr: Static Graph
:api_attr: Static Graph
Recompute Optimizer Wrapper
......@@ -4562,7 +4562,7 @@ class RecomputeOptimizer(Optimizer):
def load(self, stat_dict):
"""
:api_attr: Static Graph
:api_attr: Static Graph
load function is not supported by Recompute Optimizer for now.
:return: None
......@@ -4786,7 +4786,7 @@ class RecomputeOptimizer(Optimizer):
class LookaheadOptimizer(object):
"""
:api_attr: Static Graph
:api_attr: Static Graph
This implements the Lookahead optimizer of the
paper : https://arxiv.org/abs/1907.08610.
......@@ -4929,6 +4929,11 @@ class LookaheadOptimizer(object):
mod = layers.elementwise_mod(step, k)
with layers.control_flow.Switch() as switch:
with switch.case(step == one_var):
for param_name in params:
fast_var = main_block.var(param_name)
slow_var = param_to_slow[param_name]
layers.assign(input=fast_var, output=slow_var)
with switch.case(mod == zero_var):
for param_name in params:
fast_var = main_block.var(param_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册