提交 4a1bdf58 编写于 作者: D duanyanhui 提交者: cuicheng01

rm with_data_parallel

上级 6a55aac3
...@@ -153,12 +153,6 @@ def create_strategy(config): ...@@ -153,12 +153,6 @@ def create_strategy(config):
exec_strategy: exec strategy exec_strategy: exec strategy
""" """
build_strategy = paddle.static.BuildStrategy() build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = (
10000
if 'AMP' in config and config.AMP.get("level", "O1") == "O2" else 10)
fuse_op = True if 'AMP' in config else False fuse_op = True if 'AMP' in config else False
...@@ -172,7 +166,7 @@ def create_strategy(config): ...@@ -172,7 +166,7 @@ def create_strategy(config):
build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
build_strategy.enable_addto = enable_addto build_strategy.enable_addto = enable_addto
return build_strategy, exec_strategy return build_strategy
def dist_optimizer(config, optimizer): def dist_optimizer(config, optimizer):
...@@ -186,10 +180,9 @@ def dist_optimizer(config, optimizer): ...@@ -186,10 +180,9 @@ def dist_optimizer(config, optimizer):
Returns: Returns:
optimizer: a distributed optimizer optimizer: a distributed optimizer
""" """
build_strategy, exec_strategy = create_strategy(config) build_strategy = create_strategy(config)
dist_strategy = DistributedStrategy() dist_strategy = DistributedStrategy()
dist_strategy.execution_strategy = exec_strategy
dist_strategy.build_strategy = build_strategy dist_strategy.build_strategy = build_strategy
dist_strategy.nccl_comm_num = 1 dist_strategy.nccl_comm_num = 1
...@@ -298,14 +291,10 @@ def compile(config, program, loss_name=None, share_prog=None): ...@@ -298,14 +291,10 @@ def compile(config, program, loss_name=None, share_prog=None):
Returns: Returns:
compiled_program(): a compiled program compiled_program(): a compiled program
""" """
build_strategy, exec_strategy = create_strategy(config) build_strategy = create_strategy(config)
compiled_program = paddle.static.CompiledProgram( compiled_program = paddle.static.CompiledProgram(
program).with_data_parallel( program, build_strategy=build_strategy)
share_vars_from=share_prog,
loss_name=loss_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
return compiled_program return compiled_program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册