未验证 提交 62fd1927 编写于 作者: W WangXi 提交者: GitHub

optimizer fleet distributed strategy (#500)

上级 7b15ef1a
......@@ -288,6 +288,60 @@ def create_optimizer(config):
opt = OptimizerBuilder(config, **opt_config)
return opt(lr), lr
def create_strategy(config):
"""
Create build strategy and exec strategy.
Args:
config(dict): config
Returns:
build_strategy: build strategy
exec_strategy: exec strategy
"""
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
'use_pure_fp16', False) else 10
fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
False)
fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
enable_addto = config.get('enable_addto', fuse_op)
try:
build_strategy.fuse_bn_act_ops = fuse_bn_act_ops
except Exception as e:
logger.info(
"PaddlePaddle version 1.7.0 or higher is "
"required when you want to fuse batch_norm and activation_op.")
try:
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
except Exception as e:
logger.info(
"PaddlePaddle version 1.7.0 or higher is "
"required when you want to fuse elewise_add_act and activation_op.")
try:
build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
except Exception as e:
logger.info(
"PaddlePaddle 2.0-rc or higher is "
"required when you want to enable fuse_bn_add_act_ops strategy.")
try:
build_strategy.enable_addto = enable_addto
except Exception as e:
logger.info("PaddlePaddle 2.0-rc or higher is "
"required when you want to enable addto strategy.")
return build_strategy, exec_strategy
def dist_optimizer(config, optimizer):
"""
......@@ -300,14 +354,15 @@ def dist_optimizer(config, optimizer):
Returns:
optimizer: a distributed optimizer
"""
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 3
exec_strategy.num_iteration_per_drop_scope = 10
build_strategy, exec_strategy = create_strategy(config)
dist_strategy = DistributedStrategy()
dist_strategy.execution_strategy = exec_strategy
dist_strategy.build_strategy = build_strategy
dist_strategy.nccl_comm_num = 1
dist_strategy.fuse_all_reduce_ops = True
dist_strategy.execution_strategy = exec_strategy
dist_strategy.fuse_grad_size_in_MB = 16
optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
return optimizer
......@@ -399,46 +454,7 @@ def compile(config, program, loss_name=None, share_prog=None):
Returns:
compiled_program(): a compiled program
"""
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
'use_pure_fp16', False) else 10
fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
False)
fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
enable_addto = config.get('enable_addto', fuse_op)
try:
build_strategy.fuse_bn_act_ops = fuse_bn_act_ops
except Exception as e:
logger.info(
"PaddlePaddle version 1.7.0 or higher is "
"required when you want to fuse batch_norm and activation_op.")
try:
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
except Exception as e:
logger.info(
"PaddlePaddle version 1.7.0 or higher is "
"required when you want to fuse elewise_add_act and activation_op.")
try:
build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
except Exception as e:
logger.info(
"PaddlePaddle 2.0-rc or higher is "
"required when you want to enable fuse_bn_add_act_ops strategy.")
try:
build_strategy.enable_addto = enable_addto
except Exception as e:
logger.info("PaddlePaddle 2.0-rc or higher is "
"required when you want to enable addto strategy.")
build_strategy, exec_strategy = create_strategy(config)
compiled_program = paddle.static.CompiledProgram(
program).with_data_parallel(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册