Commit 4a1bdf58 authored by duanyanhui, committed by cuicheng01

rm with_data_parallel

Parent 6a55aac3
@@ -153,12 +153,6 @@ def create_strategy(config):
-        exec_strategy: exec strategy
     """
     build_strategy = paddle.static.BuildStrategy()
-    exec_strategy = paddle.static.ExecutionStrategy()
-    exec_strategy.num_threads = 1
-    exec_strategy.num_iteration_per_drop_scope = (
-        10000
-        if 'AMP' in config and config.AMP.get("level", "O1") == "O2" else 10)
     fuse_op = True if 'AMP' in config else False
@@ -172,7 +166,7 @@ def create_strategy(config):
     build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
     build_strategy.enable_addto = enable_addto

-    return build_strategy, exec_strategy
+    return build_strategy


 def dist_optimizer(config, optimizer):
@@ -186,10 +180,9 @@ def dist_optimizer(config, optimizer):
     Returns:
         optimizer: a distributed optimizer
     """
-    build_strategy, exec_strategy = create_strategy(config)
+    build_strategy = create_strategy(config)

     dist_strategy = DistributedStrategy()
-    dist_strategy.execution_strategy = exec_strategy
     dist_strategy.build_strategy = build_strategy

     dist_strategy.nccl_comm_num = 1
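
After this hunk, dist_optimizer attaches only a BuildStrategy to the fleet DistributedStrategy, since the ExecutionStrategy is gone. A minimal sketch of that wiring (the surrounding names and settings are illustrative, not taken from this repo; it assumes a collective job started with python -m paddle.distributed.launch):

import paddle
from paddle.distributed import fleet

# Static graph mode, as used by this training utility.
paddle.enable_static()
fleet.init(is_collective=True)

# Only a BuildStrategy is configured; the old ExecutionStrategy knobs are dropped.
build_strategy = paddle.static.BuildStrategy()
build_strategy.fuse_bn_act_ops = True

dist_strategy = fleet.DistributedStrategy()
dist_strategy.build_strategy = build_strategy
dist_strategy.nccl_comm_num = 1

# Wrap a plain optimizer into a distributed one using the strategy above.
optimizer = paddle.optimizer.Momentum(learning_rate=0.1)
optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)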
@@ -298,14 +291,10 @@ def compile(config, program, loss_name=None, share_prog=None):
     Returns:
         compiled_program(): a compiled program
     """
-    build_strategy, exec_strategy = create_strategy(config)
+    build_strategy = create_strategy(config)

-    compiled_program = paddle.static.CompiledProgram(
-        program).with_data_parallel(
-            share_vars_from=share_prog,
-            loss_name=loss_name,
-            build_strategy=build_strategy,
-            exec_strategy=exec_strategy)
+    compiled_program = paddle.static.CompiledProgram(
+        program, build_strategy=build_strategy)

     return compiled_program
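
With this change, compile() constructs the CompiledProgram directly from the program plus a BuildStrategy; with_data_parallel, share_vars_from, loss_name and exec_strategy drop out of the call. A rough usage sketch of the new-style path (the toy network and variable names below are assumptions for illustration, not code from this commit):

import numpy as np
import paddle

paddle.enable_static()

# A trivial static-graph program, only to have something to compile.
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
    loss = paddle.mean(paddle.static.nn.fc(x, size=1))

build_strategy = paddle.static.BuildStrategy()

# New style: pass build_strategy to the constructor; no with_data_parallel call.
compiled = paddle.static.CompiledProgram(main_prog, build_strategy=build_strategy)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
out = exe.run(compiled,
              feed={"x": np.zeros([4, 8], dtype="float32")},
              fetch_list=[loss])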