未验证 提交 920d998f 编写于 作者: D Dong Daxiang 提交者: GitHub

add more settings for distributed strategy (#25685)

* add more settings for distributed strategy
Basically, DistributedStrategy has several parts of configurations:
- BuildStrategy: the same as paddle.fluid.BuildStrategy, but the distributed arguments are moved out of BuildStrategy
- ExecutionStrategy: the same as paddle.fluid.ExecutionStrategy
- collective communication configs: nccl_comm_num, hierarchical allreduce and so on
- distributed algorithms: async_update(mainly used in PS), lars, lamb and so on
上级 1aaa26f1
......@@ -22,61 +22,87 @@ enum Mode {
HETER = 4; // support XPU and GPU computing server
}
message DistributedStrategy {
optional Mode mode = 1 [ default = COLLECTIVE ]; // just for serialization
// collective training strategy
optional bool amp = 2 [ default = false ];
optional int32 amp_loss_scaling = 3 [ default = 32768 ];
optional bool recompute = 4 [ default = false ];
repeated string recompute_checkpoints = 5;
optional bool localsgd = 6 [ default = false ];
optional int32 localsgd_k_step = 7 [ default = 4 ];
optional bool dgc = 8 [ default = false ];
optional bool hierachical_allreduce = 9 [ default = false ];
optional int32 hierachical_allreduce_inter_ranks = 10 [ default = 1 ];
optional int32 nccl_comm_num = 11 [ default = 1 ];
optional bool gradient_merge = 12 [ default = false ];
optional int32 gradient_merge_k_step = 13 [ default = 1 ];
optional bool sequential_execution = 14 [ default = false ];
optional bool enable_backward_optimizer_op_deps = 15 [ default = true ];
optional bool lars = 16 [ default = false ];
optional bool lamb = 17 [ default = false ];
optional bool fuse_elewise_add_act_ops = 18 [ default = false ];
optional bool fuse_bn_act_ops = 19 [ default = false ];
optional bool enable_auto_fusion = 20 [ default = false ];
optional bool fuse_relu_depthwise_conv = 21 [ default = false ];
optional bool enable_inplace = 22 [ default = false ];
optional bool fuse_all_reduce_ops = 23 [ default = false ];
optional int32 num_iteration_per_drop_scope = 24 [ default = 1 ];
optional bool sync_batch_norm = 25 [ default = false ];
optional bool fuse_all_optimizer_ops = 26 [ default = false ];
optional bool sync_nccl_allreduce = 27 [ default = true ];
optional bool fuse_broadcast_ops = 28 [ default = true ];
optional int32 num_threads = 29 [ default = 1 ];
optional int32 num_iteration_per_run = 30 [ default = 1 ];
message RecomputeConfig { repeated string checkpoints = 1; }
message AMPConfig {
optional float init_loss_scaling = 1 [ default = 32768.0 ];
optional int32 incr_every_n_steps = 2 [ default = 1000 ];
optional int32 decr_every_n_nan_or_inf = 3 [ default = 2 ];
optional float incr_ratio = 4 [ default = 2.0 ];
optional float decr_ratio = 5 [ default = 0.8 ];
optional bool use_dynamic_loss_scaling = 6 [ default = true ];
}
// pipeline training
optional bool pipeline = 101 [ default = false ];
optional int32 pipeline_micro_batch = 102;
message LocalSGDConfig { optional int32 k_steps = 1 [ default = 4 ]; }
message GradientMergeConfig {
optional int32 k_steps = 1 [ default = 1 ];
optional bool avg = 2 [ default = true ];
}
message BuildStrategy {
optional bool enable_sequential_execution = 1 [ default = false ];
optional bool fuse_elewise_add_act_ops = 2 [ default = false ];
optional bool fuse_bn_act_ops = 3 [ default = false ];
optional bool fuse_relu_depthwise_conv = 4 [ default = false ];
optional bool fuse_broadcast_ops = 5 [ default = false ];
optional bool fuse_all_optimizer_ops = 6 [ default = false ];
optional bool enable_inplace = 7 [ default = false ];
optional bool enable_backward_optimizer_op_deps = 8 [ default = true ];
optional bool cache_runtime_context = 9 [ default = false ];
}
// parameter server training
optional bool sync = 201 [ default = false ];
optional bool async = 202 [ default = true ];
optional int32 async_k_step = 203 [ default = -1 ];
optional int32 max_merge_var_num = 204 [ default = 1 ];
optional int32 send_queue_size = 205 [ default = 16 ];
optional bool independent_recv_thread = 206 [ default = false ];
optional int32 min_send_grad_num_before_recv = 207 [ default = 1 ];
optional int32 thread_pool_size = 208 [ default = 1 ];
optional int32 send_wait_times = 209 [ default = 1 ];
optional bool runtime_split_send_recv = 210 [ default = false ];
optional bool use_thread_barrier = 211 [ default = false ];
message ExecutionStrategy {
optional int32 num_threads = 1 [ default = 1 ];
optional int32 num_iteration_per_drop_scope = 2 [ default = 10 ];
optional int32 num_iteration_per_run = 3 [ default = 1 ];
optional bool use_thread_barrier = 4 [ default = false ];
}
message AsyncConfig {
optional int32 k_steps = 1 [ default = 1 ];
optional int32 max_merge_var_num = 2 [ default = 1 ];
optional int32 send_queue_size = 3 [ default = 16 ];
optional bool independent_recv_thread = 4 [ default = false ];
optional int32 min_send_grad_num_before_recv = 5 [ default = 1 ];
optional int32 thread_pool_size = 6 [ default = 1 ];
optional int32 send_wait_times = 7 [ default = 1 ];
optional bool runtime_split_send_recv = 8 [ default = false ];
}
message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
message DistributedStrategy {
// bool options
optional Mode mode = 1 [ default = COLLECTIVE ];
optional bool amp = 2 [ default = false ];
optional bool recompute = 3 [ default = false ];
optional bool localsgd = 4 [ default = false ];
optional bool dgc = 5 [ default = false ];
optional bool gradient_merge = 6 [ default = false ];
optional bool lars = 7 [ default = false ];
optional bool lamb = 8 [ default = false ];
optional bool pipeline = 9 [ default = false ];
optional bool elastic = 10 [ default = false ];
optional bool auto = 11 [ default = false ];
optional bool async = 12 [ default = true ];
optional bool sync_nccl_allreduce = 13 [ default = true ];
optional int32 nccl_comm_num = 14 [ default = 1 ];
optional bool use_hierarchical_allreduce = 15 [ default = false ];
optional int32 hierarchical_allreduce_inter_nranks = 16 [ default = 1 ];
optional bool sync_batch_norm = 17 [ default = false ];
optional bool fuse_all_reduce_ops = 18 [ default = true ];
// optional bool enable_backward_optimizer_op_deps = 19 [ default = true ];
// elastic deep learning strategies
optional bool elastic = 301 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
optional LocalSGDConfig localsgd_configs = 103;
optional GradientMergeConfig gradient_merge_configs = 104;
optional PipelineConfig pipeline_configs = 106;
optional AsyncConfig async_configs = 107;
// auto parallel
optional bool auto = 401 [ default = false ];
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
}
message DistributedJobInfo {
......
......@@ -18,39 +18,6 @@ from .meta_optimizer_base import MetaOptimizerBase
from ..base.private_helper_function import wait_server_ready
def get_build_strategy(dist_strategy):
build_strategy = paddle.BuildStrategy()
build_strategy.enable_sequential_execution = \
dist_strategy.sequential_execution
build_strategy.remove_unnecessary_lock = True
build_strategy.fuse_elewise_add_act_ops = \
dist_strategy.fuse_elewise_add_act_ops
build_strategy.fuse_bn_act_ops = \
dist_strategy.fuse_bn_act_ops
build_strategy.enable_auto_fusion = \
dist_strategy.enable_auto_fusion
build_strategy.fuse_relu_depthwise_conv = \
dist_strategy.fuse_relu_depthwise_conv
build_strategy.fuse_broadcast_ops = \
dist_strategy.fuse_broadcast_ops
build_strategy.sync_batch_norm = \
dist_strategy.sync_batch_norm
return build_strategy
def get_execution_strategy(dist_strategy):
execution_strategy = paddle.ExecutionStrategy()
execution_strategy.num_threads = \
dist_strategy.num_threads
execution_strategy.num_iteration_per_drop_scope = \
dist_strategy.num_iteration_per_drop_scope
execution_strategy.num_iteration_per_run = \
dist_strategy.num_iteration_per_run
execution_strategy.use_thread_barrier = \
dist_strategy.use_thread_barrier
return execution_strategy
class GraphExecutionOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(GraphExecutionOptimizer, self).__init__(optimizer)
......@@ -76,7 +43,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
pass
# should fix the variable
def _setup_nccl_op(self, startup_program, main_program):
def _setup_nccl_op(self, startup_program, main_program, build_strategy):
trainer_endpoints = self.role_maker.get_trainer_endpoints()
trainer_id = self.role_maker.worker_index()
current_endpoint = self.role_maker.get_trainer_endpoints()[trainer_id]
......@@ -88,14 +55,14 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
wait_server_ready(other_trainer_endpoints)
nccl_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
for i in range(1, self.user_defined_strategy.nccl_comm_num):
for i in range(1, build_strategy.nccl_comm_num):
startup_program.global_block().create_var(
name="NCCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
if self.user_defined_strategy.hierachical_allreduce:
for i in range(0, self.user_defined_strategy.nccl_comm_num):
if build_strategy.use_hierarchical_allreduce:
for i in range(0, build_strategy.nccl_comm_num):
startup_program.global_block().create_var(
name="Hierarchical_inter_NCCLID_{}".format(i),
persistable=True,
......@@ -112,48 +79,80 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
attrs={
"trainers": trainer_endpoints,
"trainer_id": trainer_id,
"nccl_comm_num": self.user_defined_strategy.nccl_comm_num,
"nccl_comm_num": build_strategy.nccl_comm_num,
"use_hierarchical_allreduce":
self.user_defined_strategy.hierachical_allreduce,
build_strategy.use_hierarchical_allreduce,
"hierarchical_allreduce_inter_ranks":
self.user_defined_strategy.hierachical_allreduce_inter_ranks
build_strategy.hierarchical_allreduce_inter_nranks
})
def _try_to_compile(self, startup_program, main_program, loss):
build_strategy = get_build_strategy(self.user_defined_strategy)
exe_strategy = get_execution_strategy(self.user_defined_strategy)
import copy
dist_strategy = self.user_defined_strategy
local_build_strategy = paddle.fluid.BuildStrategy()
local_build_strategy.enable_sequential_execution = \
dist_strategy.build_strategy.enable_sequential_execution
local_build_strategy.fuse_elewise_add_act_ops = \
dist_strategy.build_strategy.fuse_elewise_add_act_ops
local_build_strategy.fuse_bn_act_ops = \
dist_strategy.build_strategy.fuse_bn_act_ops
local_build_strategy.enable_auto_fusion = \
dist_strategy.build_strategy.enable_auto_fusion
local_build_strategy.fuse_relu_depthwise_conv = \
dist_strategy.build_strategy.fuse_relu_depthwise_conv
local_build_strategy.fuse_broadcast_ops = \
dist_strategy.build_strategy.fuse_broadcast_ops
local_build_strategy.fuse_all_optimizer_ops = \
dist_strategy.build_strategy.fuse_all_optimizer_ops
local_build_strategy.enable_inplace = \
dist_strategy.build_strategy.enable_inplace
local_build_strategy.use_hierarchical_allreduce = \
dist_strategy.use_hierarchical_allreduce
local_build_strategy.hierarchical_allreduce_inter_nranks = \
dist_strategy.hierarchical_allreduce_inter_nranks
local_build_strategy.sync_batch_norm = \
dist_strategy.sync_batch_norm
local_build_strategy.fuse_all_reduce_ops = \
dist_strategy.fuse_all_reduce_ops
local_build_strategy.nccl_comm_num = \
dist_strategy.nccl_comm_num
exe_strategy = self.user_defined_strategy.execution_strategy
node_num = self.role_maker.worker_num()
if self.role_maker._is_collective:
assert node_num >= 1, "nccl2 node_num must >= 1, now:{}" % node_num
if node_num <= 1:
# local mode
if self.user_defined_strategy.nccl_comm_num > 1:
if local_build_strategy.nccl_comm_num > 1:
logging.warn("set nccl_comm_num=1 since you only have 1 node.")
self.user_defined_strategy.nccl_comm_num = 1
local_build_strategy.nccl_comm_num = 1
if self.user_defined_strategy.hierachical_allreduce:
if local_build_strategy.use_hierarchical_allreduce:
logging.warn(
"set hierachical_allreduce=False since you only have 1 node."
)
self.user_defined_strategy.hierachical_allreduce = False
local_build_strategy.use_hierarchical_allreduce = False
sync_allreduce = self.user_defined_strategy.sync_nccl_allreduce
sync_allreduce = dist_strategy.sync_nccl_allreduce
if sync_allreduce:
exe_strategy.num_threads = self.user_defined_strategy.nccl_comm_num + 1
if self.user_defined_strategy.hierachical_allreduce:
exe_strategy.num_threads = 2 * self.user_defined_strategy.nccl_comm_num + 1
paddle.fluid.framework.set_flags({
"FLAGS_sync_nccl_allreduce": True
})
exe_strategy.num_threads = local_build_strategy.nccl_comm_num + 1
if local_build_strategy.use_hierarchical_allreduce:
exe_strategy.num_threads = 2 * local_build_strategy.nccl_comm_num + 1
if exe_strategy.num_threads > 4:
logging.warn(
"if you use hierachical_allreduce or "
"with multi nccl comm, please export FLAGS_sync_nccl_allreduce = 0"
"with multi nccl comm, please set distributed_strategy.sync_nccl_allreduce=False"
)
# TODO(guru4elephant): should be an independent optimizer
sync_batch_norm = self.user_defined_strategy.sync_batch_norm
sync_batch_norm = local_build_strategy.sync_batch_norm
if sync_batch_norm:
self.user_defined_strategy.nccl_comm_num = 1
self.user_defined_strategy.hierachical_allreduce = False
local_build_strategy.nccl_comm_num = 1
local_build_strategy.use_hierarchical_allreduce = False
exe_strategy.num_threads = 1
logging.warn(
"use sync_batch_norm will hang when set num_threads > 1, so "
......@@ -161,19 +160,19 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
)
# TODO(guru4elephant): should be an independent optimizer
self._setup_nccl_op(startup_program, main_program)
self._setup_nccl_op(startup_program, main_program, local_build_strategy)
build_strategy.num_trainers = self.role_maker.worker_num()
build_strategy.trainer_id = self.role_maker.worker_index()
build_strategy.trainers_endpoints = self.role_maker.get_trainer_endpoints(
local_build_strategy.num_trainers = self.role_maker.worker_num()
local_build_strategy.trainer_id = self.role_maker.worker_index()
local_build_strategy.trainers_endpoints = self.role_maker.get_trainer_endpoints(
)
build_strategy.enable_backward_optimizer_op_deps = True
local_build_strategy.enable_backward_optimizer_op_deps = True
self._compiled_program = compiler.CompiledProgram(main_program)
self._compiled_program.with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
build_strategy=local_build_strategy,
exec_strategy=exe_strategy,
share_vars_from=None)
......@@ -188,7 +187,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
startup_program = paddle.default_startup_program()
compiled_program = self._try_to_compile(startup_program,
loss.block.program, loss)
loss.block.program.graph = compiled_program
loss.block.program._graph = compiled_program
# just return self.optimizer_ops and self.param_grads
return None, None
......@@ -3936,6 +3936,9 @@ class Program(object):
# appending gradients times
self._appending_grad_times = 0
# compiled program, i.e. Graph
self._graph = None
def global_seed(self, seed=0):
"""
Set global seed for Program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册