未验证 提交 f3ce7dda 编写于 作者: W WangXi 提交者: GitHub

Disable all-reduce fusion when using DGC & move DGC strategy checks from ParallelExecutor (PE) to the compiler, test=develop (#22917)

上级 cec3cfba
...@@ -356,6 +356,16 @@ class CompiledProgram(object): ...@@ -356,6 +356,16 @@ class CompiledProgram(object):
if self._build_strategy.sync_batch_norm: if self._build_strategy.sync_batch_norm:
self._build_strategy.enable_sequential_execution = True self._build_strategy.enable_sequential_execution = True
if self._program is not None and self._program._enable_dgc:
assert use_cuda, "DGC only used under cuda"
assert self._build_strategy.num_trainers * len(
places) > 1, "DGC is not useful for single card training"
assert self._build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce, "DGC \
only used for AllReduce BuildStrategy"
# DGC doesn't support fused all-reduce ops for now, so turn fuse_all_reduce_ops off.
self._build_strategy.fuse_all_reduce_ops = False
self._persistable_vars = [] self._persistable_vars = []
for node in self._graph.nodes(): for node in self._graph.nodes():
if node.is_var() and node.var() is not None and node.var().persistable() and \ if node.is_var() and node.var() is not None and node.var().persistable() and \
......
...@@ -175,15 +175,6 @@ class ParallelExecutor(object): ...@@ -175,15 +175,6 @@ class ParallelExecutor(object):
) if use_cuda else framework.cpu_places() ) if use_cuda else framework.cpu_places()
self._scope = scope if scope is not None else executor.global_scope() self._scope = scope if scope is not None else executor.global_scope()
if main_program is not None and main_program._enable_dgc:
assert build_strategy.num_trainers > 1, "dgc is not useful when num_trainers <= 1"
assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce, "dgc \
only used for allreduce"
assert build_strategy.num_trainers * len(
self._places) > 1, "dgc is not useful for single card training"
assert use_cuda, "dgc only used under cuda"
main_program = main_program if main_program is not None \ main_program = main_program if main_program is not None \
else framework.default_main_program() else framework.default_main_program()
......
...@@ -334,10 +334,6 @@ class TestDistRunnerBase(object): ...@@ -334,10 +334,6 @@ class TestDistRunnerBase(object):
build_stra.num_trainers = 1 build_stra.num_trainers = 1
build_stra.trainer_id = 0 build_stra.trainer_id = 0
if args.use_dgc:
# fuse_all_reduce_ops requires that gradients are not sparse types
build_stra.fuse_all_reduce_ops = False
print_to_err(type(self).__name__, "begin to compile with data parallel") print_to_err(type(self).__name__, "begin to compile with data parallel")
binary = compiler.CompiledProgram(trainer_prog).with_data_parallel( binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
loss_name=avg_cost.name, loss_name=avg_cost.name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册