Unverified commit 4ec51e02, authored by Dong Daxiang, committed by GitHub

【paddle.fleet】Clear disable (#26334)

* add check approval
test=develop
Parent 3b2c580a
@@ -76,6 +76,18 @@ class StrategyCompiler(StrategyCompilerBase):
             opt._disable_strategy(valid_strategy)
         return valid_strategy
 
"""
Meta Optimizer Type A: rewrite forward, backward. e.g. recompute, async, sync, pipeline.
results will be splitted in async, sync, pipeline
Meta Optimizer Type B: rewrite forward,
e.g. AMP and the corresponding backward is generated by rewritten forward
Meta Opitmizer Type B: rewrite backward. e.g. gradient fusion
Meta Optimizer Type D: rewrite optimize. e.g. lars, lamb, localsgd, gradient merge, dgc
Meta Optimizer Type E: only transpile to Graph structure for runtime,
currently, grad fusion and kernel fusion, sync batch-norm included.
we will remove grad fusion and sync batch-norm
"""
     def generate_optimizer(self, loss, role_maker, optimizer,
                            user_defined_strategy, meta_optimizer_list,
                            graph_optimizer_list):
......
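For orientation, the two context lines at the top of the hunk above show the disable pass: meta optimizers that will not be applied are asked to clear their own settings on a copied strategy. The following is an illustrative, self-contained sketch of that flow; every name except _disable_strategy is made up and the real compiler code may be structured differently.

# Hypothetical sketch of the "disable what cannot run" pass.
import copy

class _FakeStrategy:
    # Stand-in for DistributedStrategy: one feature switch plus its configs dict.
    def __init__(self):
        self.amp = True
        self.amp_configs = {"custom": "options"}

class _FakeAMPOptimizer:
    def _disable_strategy(self, dist_strategy):
        # The pattern this commit standardizes on:
        # turn the switch off and empty the configs dict.
        dist_strategy.amp = False
        dist_strategy.amp_configs = {}

def get_valid_strategy(dist_strategy, cannot_apply_optimizers):
    # Work on a copy so the user's own strategy object is untouched.
    valid_strategy = copy.copy(dist_strategy)
    for opt in cannot_apply_optimizers:
        opt._disable_strategy(valid_strategy)
    return valid_strategy

if __name__ == "__main__":
    valid = get_valid_strategy(_FakeStrategy(), [_FakeAMPOptimizer()])
    print(valid.amp, valid.amp_configs)  # False {}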
@@ -37,6 +37,7 @@ class AMPOptimizer(MetaOptimizerBase):
     def _disable_strategy(self, dist_strategy):
         dist_strategy.amp = False
+        dist_strategy.amp_configs = {}
 
     def minimize_impl(self,
                       loss,
......
@@ -33,6 +33,9 @@ class AsyncGraphExecutionOptimizer(AsyncMetaOptimizer):
         return True
 
+    def _disable_strategy(self, dist_strategy):
+        dist_strategy.a_sync_configs = {}
+
     def _is_graph_out(self):
         return True
......
@@ -139,4 +139,4 @@ class AsyncMetaOptimizer(MetaOptimizerBase):
         return None, None
 
     def _disable_strategy(self, dist_strategy):
-        self.user_defined_strategy.a_sync_configs["k_steps"] = -1
+        self.user_defined_strategy.a_sync_configs = {}
@@ -68,11 +68,7 @@ class DGCOptimizer(MetaOptimizerBase):
     def _disable_strategy(self, dist_strategy):
         dist_strategy.dgc = False
-        dist_strategy.dgc_configs = {
-            'rampup_begin_step': 0,
-            'rampup_step': 1,
-            'sparsity': [0.999]
-        }
+        dist_strategy.dgc_configs = {}
 
     def backward(self,
                  loss,
......
@@ -40,7 +40,7 @@ class GradientMergeOptimizer(MetaOptimizerBase):
     def _disable_strategy(self, dist_strategy):
         dist_strategy.gradient_merge = False
-        dist_strategy.gradient_merge_configs = {"k_steps": 1, "avg": True}
+        dist_strategy.gradient_merge_configs = {}
 
     def minimize_impl(self,
                       loss,
......
@@ -74,10 +74,7 @@ class LambOptimizer(MetaOptimizerBase):
     def _disable_strategy(self, dist_strategy):
         dist_strategy.lamb = False
-        dist_strategy.lamb_configs = {
-            'lamb_weight_decay': 0.01,
-            'exclude_from_weight_decay': [],
-        }
+        dist_strategy.lamb_configs = {}
 
     def backward(self,
                  loss,
......
@@ -58,10 +58,7 @@ class LarsOptimizer(MetaOptimizerBase):
     def _disable_strategy(self, dist_strategy):
         dist_strategy.lars = False
-        dist_strategy.lars_configs = {
-            'lars_coeff': 0.001,
-            'lars_weight_decay': 0.0005,
-        }
+        dist_strategy.lars_configs = {}
 
     def backward(self,
                  loss,
......
@@ -39,7 +39,7 @@ class LocalSGDOptimizer(MetaOptimizerBase):
     def _disable_strategy(self, dist_strategy):
         dist_strategy.localsgd = False
-        dist_strategy.localsgd_configs = {'k_steps': 1}
+        dist_strategy.localsgd_configs = {}
 
     def snapshot_name(self, param_name):
         return param_name + self.snapshot_key
......
@@ -38,6 +38,7 @@ class MetaOptimizerBase(object):
     def _can_update(self, optimizer):
         if str(optimizer.__class__.__name__) in self.meta_optimizers_white_list:
             return True
+        return False
 
     def _disable_strategy(self, dist_strategy):
         raise NotImplementedError("you should implement disable strategy in {}".
......
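One behavioral note on the MetaOptimizerBase hunk: before this change, _can_update simply fell off the end when the optimizer was not in the white list, implicitly returning None; the added return False makes the boolean contract explicit without changing truthiness. A small standalone illustration (the class and the white-list entry here are made up):

# Why "return False" only makes the intent explicit: a Python function
# that falls off the end returns None, which is already falsy.
class _Demo:
    meta_optimizers_white_list = ["Momentum"]

    def _can_update_old(self, optimizer_cls_name):
        if optimizer_cls_name in self.meta_optimizers_white_list:
            return True
        # implicit: returns None

    def _can_update_new(self, optimizer_cls_name):
        if optimizer_cls_name in self.meta_optimizers_white_list:
            return True
        return False

d = _Demo()
print(d._can_update_old("SGD"), d._can_update_new("SGD"))            # None False
print(bool(d._can_update_old("SGD")) == d._can_update_new("SGD"))    # True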
@@ -110,7 +110,7 @@ class PipelineOptimizer(MetaOptimizerBase):
     def _disable_strategy(self, dist_strategy):
         dist_strategy.pipeline = False
-        dist_strategy.pipeline_configs = {"micro_batch": 1}
+        dist_strategy.pipeline_configs = {}
 
     def minimize_impl(self,
                       loss,
......
@@ -42,7 +42,7 @@ class RecomputeOptimizer(MetaOptimizerBase):
     def _disable_strategy(self, dist_strategy):
         dist_strategy.recompute = False
-        dist_strategy.recompute_configs = {"checkpoints": []}
+        dist_strategy.recompute_configs = {}
 
     def backward(self,
                  loss,
......
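From the user side, the *_configs fields touched throughout this diff are the per-feature option dicts on fleet's DistributedStrategy. A hedged usage sketch, assuming the paddle.distributed.fleet API of this release line; the config keys echo the hard-coded defaults that the old _disable_strategy bodies set and that this commit removes:

# Sketch only: assumes paddle is installed and that DistributedStrategy
# exposes gradient_merge / gradient_merge_configs as shown in the hunks above.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.gradient_merge = True
strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}

# After this commit, a meta optimizer that cannot be applied clears both
# fields instead of restoring hand-written defaults such as
# {"k_steps": 1, "avg": True}:
#   strategy.gradient_merge         -> False
#   strategy.gradient_merge_configs -> {}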