Unverified commit 7ae10900, authored by mapingshuo, committed via GitHub

fix slow var initialize, test=develop (#26516)

Parent commit: 5407e327
...@@ -1141,7 +1141,7 @@ class MomentumOptimizer(Optimizer):
class DGCMomentumOptimizer(Optimizer):
    """
    :api_attr: Static Graph

    DGC (Deep Gradient Compression) Momentum Optimizer. Original paper is https://arxiv.org/abs/1712.01887
...@@ -3067,7 +3067,7 @@ Lamb = LambOptimizer
class ModelAverage(Optimizer):
    """
    :api_attr: Static Graph

    The ModelAverage optimizer accumulates specific continuous historical parameters
    during training. The accumulated historical range can be controlled by the passed
...@@ -3376,7 +3376,7 @@ class ModelAverage(Optimizer):
class ExponentialMovingAverage(object):
    """
    :api_attr: Static Graph

    Compute the moving average of parameters with exponential decay.
    Given a parameter :math:`\\theta`, its exponential moving average (EMA)
...@@ -3626,7 +3626,7 @@ class ExponentialMovingAverage(object):
class PipelineOptimizer(object):
    """
    :api_attr: Static Graph

    Pipeline Optimizer: Make a program to run as pipeline, that is splitting a
    program into multiple sections (sub-programs) and each section run on a
...@@ -4477,7 +4477,7 @@ class PipelineOptimizer(object):
class RecomputeOptimizer(Optimizer):
    """
    :api_attr: Static Graph

    Recompute Optimizer Wrapper
...@@ -4562,7 +4562,7 @@ class RecomputeOptimizer(Optimizer):
    def load(self, stat_dict):
        """
        :api_attr: Static Graph

        load function is not supported by Recompute Optimizer for now.
        :return: None
...@@ -4786,7 +4786,7 @@ class RecomputeOptimizer(Optimizer):
class LookaheadOptimizer(object):
    """
    :api_attr: Static Graph

    This implements the Lookahead optimizer of the
    paper : https://arxiv.org/abs/1907.08610.
...@@ -4929,6 +4929,11 @@ class LookaheadOptimizer(object):
        mod = layers.elementwise_mod(step, k)
        with layers.control_flow.Switch() as switch:
+           with switch.case(step == one_var):
+               for param_name in params:
+                   fast_var = main_block.var(param_name)
+                   slow_var = param_to_slow[param_name]
+                   layers.assign(input=fast_var, output=slow_var)
            with switch.case(mod == zero_var):
                for param_name in params:
                    fast_var = main_block.var(param_name)
......
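The added switch case copies each fast parameter into its slow counterpart when `step == 1`, so the slow weights start from the fast weights rather than from whatever value the slow vars held at creation. The sketch below is a minimal, framework-free illustration of that Lookahead bookkeeping under stated assumptions; the class and parameter names (`LookaheadSketch`, `k`, `alpha`) are illustrative and not Paddle's API.

```python
# Minimal sketch of the Lookahead update rule (https://arxiv.org/abs/1907.08610),
# illustrating the fix above: slow weights are initialized from the fast
# weights at step 1, instead of being left at their creation-time values.
class LookaheadSketch:
    def __init__(self, k=5, alpha=0.5):
        self.k = k          # sync period: interpolate every k steps
        self.alpha = alpha  # slow-weights step size
        self.step = 0
        self.slow = None    # slow weights, initialized lazily at step 1

    def update(self, fast):
        """Apply one Lookahead bookkeeping step to the list of fast weights."""
        self.step += 1
        if self.step == 1:
            # The fix: copy fast vars into slow vars on the first step
            # (mirrors `layers.assign(input=fast_var, output=slow_var)`).
            self.slow = list(fast)
            return fast
        if self.step % self.k == 0:
            # Every k steps: slow += alpha * (fast - slow), then fast = slow.
            self.slow = [s + self.alpha * (f - s)
                         for s, f in zip(self.slow, fast)]
            fast[:] = self.slow
        return fast
```

Without the step-1 case, the first interpolation at `step % k == 0` would blend the fast weights toward an uninitialized slow copy, which is the bug the commit message describes.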