From 920d998f1ee941ed8facf41f34674e17e07b4377 Mon Sep 17 00:00:00 2001 From: Dong Daxiang <35550832+guru4elephant@users.noreply.github.com> Date: Tue, 28 Jul 2020 12:47:46 +0800 Subject: [PATCH] add more settings for distributed strategy (#25685) * add more settings for distributed strategy Basically, DistributedStrategy has several parts of configurations: - BuildStrategy: the same as paddle.fluid.BuildStrategy, but the distributed arguments are moved out of BuildStrategy - ExecutionStrategy: the same as paddle.fluid.ExecutionStrategy - collective communication configs: nccl_comm_num, hierarchical allreduce and so on - distributed algorithms: async_update(mainly used in PS), lars, lamb and so on --- .../framework/distributed_strategy.proto | 128 +-- .../paddle/fleet/base/distributed_strategy.py | 761 +++++++++--------- .../graph_execution_optimizer.py | 123 ++- python/paddle/fluid/framework.py | 3 + .../test_fleet_distributed_strategy.py | 387 ++++----- 5 files changed, 663 insertions(+), 739 deletions(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index cc7d60b148d..d547800bf6c 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -22,61 +22,87 @@ enum Mode { HETER = 4; // support XPU and GPU computing server } -message DistributedStrategy { - optional Mode mode = 1 [ default = COLLECTIVE ]; // just for serialization - // collective training strategy - optional bool amp = 2 [ default = false ]; - optional int32 amp_loss_scaling = 3 [ default = 32768 ]; - optional bool recompute = 4 [ default = false ]; - repeated string recompute_checkpoints = 5; - optional bool localsgd = 6 [ default = false ]; - optional int32 localsgd_k_step = 7 [ default = 4 ]; - optional bool dgc = 8 [ default = false ]; - optional bool hierachical_allreduce = 9 [ default = false ]; - optional int32 hierachical_allreduce_inter_ranks = 10 [ default = 1 ]; - optional int32 nccl_comm_num = 11 [ default = 1 ]; - optional bool gradient_merge = 12 [ default = false ]; - optional int32 gradient_merge_k_step = 13 [ default = 1 ]; - optional bool sequential_execution = 14 [ default = false ]; - optional bool enable_backward_optimizer_op_deps = 15 [ default = true ]; - optional bool lars = 16 [ default = false ]; - optional bool lamb = 17 [ default = false ]; - optional bool fuse_elewise_add_act_ops = 18 [ default = false ]; - optional bool fuse_bn_act_ops = 19 [ default = false ]; - optional bool enable_auto_fusion = 20 [ default = false ]; - optional bool fuse_relu_depthwise_conv = 21 [ default = false ]; - optional bool enable_inplace = 22 [ default = false ]; - optional bool fuse_all_reduce_ops = 23 [ default = false ]; - optional int32 num_iteration_per_drop_scope = 24 [ default = 1 ]; - optional bool sync_batch_norm = 25 [ default = false ]; - optional bool fuse_all_optimizer_ops = 26 [ default = false ]; - optional bool sync_nccl_allreduce = 27 [ default = true ]; - optional bool fuse_broadcast_ops = 28 [ default = true ]; - optional int32 num_threads = 29 [ default = 1 ]; - optional int32 num_iteration_per_run = 30 [ default = 1 ]; +message RecomputeConfig { repeated string checkpoints = 1; } + +message AMPConfig { + optional float init_loss_scaling = 1 [ default = 32768.0 ]; + optional int32 incr_every_n_steps = 2 [ default = 1000 ]; + optional int32 decr_every_n_nan_or_inf = 3 [ default = 2 ]; + optional float incr_ratio = 4 [ default = 2.0 ]; + optional float decr_ratio = 5 [ 
default = 0.8 ]; + optional bool use_dynamic_loss_scaling = 6 [ default = true ]; +} - // pipeline training - optional bool pipeline = 101 [ default = false ]; - optional int32 pipeline_micro_batch = 102; +message LocalSGDConfig { optional int32 k_steps = 1 [ default = 4 ]; } + +message GradientMergeConfig { + optional int32 k_steps = 1 [ default = 1 ]; + optional bool avg = 2 [ default = true ]; +} + +message BuildStrategy { + optional bool enable_sequential_execution = 1 [ default = false ]; + optional bool fuse_elewise_add_act_ops = 2 [ default = false ]; + optional bool fuse_bn_act_ops = 3 [ default = false ]; + optional bool fuse_relu_depthwise_conv = 4 [ default = false ]; + optional bool fuse_broadcast_ops = 5 [ default = false ]; + optional bool fuse_all_optimizer_ops = 6 [ default = false ]; + optional bool enable_inplace = 7 [ default = false ]; + optional bool enable_backward_optimizer_op_deps = 8 [ default = true ]; + optional bool cache_runtime_context = 9 [ default = false ]; +} - // parameter server training - optional bool sync = 201 [ default = false ]; - optional bool async = 202 [ default = true ]; - optional int32 async_k_step = 203 [ default = -1 ]; - optional int32 max_merge_var_num = 204 [ default = 1 ]; - optional int32 send_queue_size = 205 [ default = 16 ]; - optional bool independent_recv_thread = 206 [ default = false ]; - optional int32 min_send_grad_num_before_recv = 207 [ default = 1 ]; - optional int32 thread_pool_size = 208 [ default = 1 ]; - optional int32 send_wait_times = 209 [ default = 1 ]; - optional bool runtime_split_send_recv = 210 [ default = false ]; - optional bool use_thread_barrier = 211 [ default = false ]; +message ExecutionStrategy { + optional int32 num_threads = 1 [ default = 1 ]; + optional int32 num_iteration_per_drop_scope = 2 [ default = 10 ]; + optional int32 num_iteration_per_run = 3 [ default = 1 ]; + optional bool use_thread_barrier = 4 [ default = false ]; +} + +message AsyncConfig { + optional int32 k_steps = 1 [ default = 1 ]; + optional int32 max_merge_var_num = 2 [ default = 1 ]; + optional int32 send_queue_size = 3 [ default = 16 ]; + optional bool independent_recv_thread = 4 [ default = false ]; + optional int32 min_send_grad_num_before_recv = 5 [ default = 1 ]; + optional int32 thread_pool_size = 6 [ default = 1 ]; + optional int32 send_wait_times = 7 [ default = 1 ]; + optional bool runtime_split_send_recv = 8 [ default = false ]; +} + +message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; } + +message DistributedStrategy { + // bool options + optional Mode mode = 1 [ default = COLLECTIVE ]; + optional bool amp = 2 [ default = false ]; + optional bool recompute = 3 [ default = false ]; + optional bool localsgd = 4 [ default = false ]; + optional bool dgc = 5 [ default = false ]; + optional bool gradient_merge = 6 [ default = false ]; + optional bool lars = 7 [ default = false ]; + optional bool lamb = 8 [ default = false ]; + optional bool pipeline = 9 [ default = false ]; + optional bool elastic = 10 [ default = false ]; + optional bool auto = 11 [ default = false ]; + optional bool async = 12 [ default = true ]; + optional bool sync_nccl_allreduce = 13 [ default = true ]; + optional int32 nccl_comm_num = 14 [ default = 1 ]; + optional bool use_hierarchical_allreduce = 15 [ default = false ]; + optional int32 hierarchical_allreduce_inter_nranks = 16 [ default = 1 ]; + optional bool sync_batch_norm = 17 [ default = false ]; + optional bool fuse_all_reduce_ops = 18 [ default = true ]; + // optional bool 
enable_backward_optimizer_op_deps = 19 [ default = true ];

-  // elastic deep learning strategies
-  optional bool elastic = 301 [ default = false ];
+  optional RecomputeConfig recompute_configs = 101;
+  optional AMPConfig amp_configs = 102;
+  optional LocalSGDConfig localsgd_configs = 103;
+  optional GradientMergeConfig gradient_merge_configs = 104;
+  optional PipelineConfig pipeline_configs = 106;
+  optional AsyncConfig async_configs = 107;

-  // auto parallel
-  optional bool auto = 401 [ default = false ];
+  optional BuildStrategy build_strategy = 201;
+  optional ExecutionStrategy execution_strategy = 202;
 }

 message DistributedJobInfo {
diff --git a/python/paddle/fleet/base/distributed_strategy.py b/python/paddle/fleet/base/distributed_strategy.py
index fdc5b22ae4c..74629cef615 100644
--- a/python/paddle/fleet/base/distributed_strategy.py
+++ b/python/paddle/fleet/base/distributed_strategy.py
@@ -12,11 +12,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import paddle
 from paddle.fleet.proto import distributed_strategy_pb2
 from paddle.fluid.framework import Variable
 import google.protobuf.text_format


+def get_msg_dict(msg):
+    res_dict = {}
+    fields = msg.DESCRIPTOR.fields
+    for f in fields:
+        res_dict[f.name] = getattr(msg, f.name)
+    return res_dict
+
+
+def assign_configs_value(msg, config):
+    fields = msg.DESCRIPTOR.fields
+    for key in config:
+        for f in fields:
+            if key == f.name:
+                if f.label == 3:  # LABEL_REPEATED
+                    getattr(msg, f.name).extend(config[f.name])
+                elif f.label == 1 or f.label == 2:  # LABEL_OPTIONAL / LABEL_REQUIRED
+                    setattr(msg, f.name, config[f.name])
+
+
+def check_configs_key(msg, config, field_name):
+    key_list = msg.DESCRIPTOR.fields_by_name.keys()
+    for key in config:
+        assert key in key_list, "key:{} not in {}".format(key, field_name)
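
For reference, the three helpers above compose like this; a minimal sketch, assuming
it runs inside this module (the label constants 1/2/3 are protobuf's LABEL_OPTIONAL,
LABEL_REQUIRED and LABEL_REPEATED):

    from paddle.fleet.proto import distributed_strategy_pb2

    msg = distributed_strategy_pb2.DistributedStrategy()
    cfg = {"checkpoints": ["fc_0"]}  # "checkpoints" is a repeated field
    check_configs_key(msg.recompute_configs, cfg, "recompute_configs")
    assign_configs_value(msg.recompute_configs, cfg)  # repeated -> extend()
    print(get_msg_dict(msg.recompute_configs))        # {'checkpoints': ['fc_0']}
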
 class DistributedJobInfo(object):
     """
     DistributedJobInfo will serialize all distributed training information
@@ -56,207 +82,241 @@ class DistributedJobInfo(object):

 class DistributedStrategy(object):
     def __init__(self):
+        """
+        DistributedStrategy is the main configuration entry for distributed training of Paddle.
+        All of the distributed training configurations can be set in a DistributedStrategy,
+        such as automatic mixed precision (AMP), Layer-wise Adaptive Rate Scaling (LARS),
+        and asynchronous parameter server updates (async SGD).
+
+        DistributedStrategy can be serialized into a protobuf file, and deserialized from one.
+
+        Users who run local training usually configure BuildStrategy and ExecutionStrategy;
+        DistributedStrategy accepts configurations from both.
+
+        """
         self.strategy = distributed_strategy_pb2.DistributedStrategy()

     def save_to_prototxt(self, output):
+        """
+        Serialize the current DistributedStrategy to a string and save it to an output file
+
+        Examples:
+          .. code-block:: python
+
+            import paddle.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.dgc = True
+            strategy.recompute = True
+            strategy.recompute_configs = {"checkpoints": ["x"]}
+            strategy.save_to_prototxt("dist_strategy.prototxt")
+        """
         with open(output, "w") as fout:
             fout.write(str(self.strategy))

     def load_from_prototxt(self, pb_file):
-        f = open(pb_file, 'r')
-        self.strategy = google.protobuf.text_format.Merge(
-            str(f.read()), self.strategy)
-
-    @property
-    def amp(self):
-        return self.strategy.amp
-
-    @amp.setter
-    def amp(self, flag):
+        """
+        Load a DistributedStrategy from a prototxt file
+
+        Examples:
+          .. code-block:: python
+
+            import paddle.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.load_from_prototxt("dist_strategy.prototxt")
+        """
+        with open(pb_file, 'r') as f:
+            self.strategy = google.protobuf.text_format.Merge(
+                str(f.read()), self.strategy)
+
+    @property
+    def execution_strategy(self):
+        """
+        Configure ExecutionStrategy for DistributedStrategy
+
+        Examples:
+          .. code-block:: python
+
+            exe_strategy = paddle.fluid.ExecutionStrategy()
+            exe_strategy.num_threads = 10
+            exe_strategy.num_iteration_per_drop_scope = 10
+            exe_strategy.num_iteration_per_run = 10
+
+            strategy = paddle.fleet.DistributedStrategy()
+            strategy.execution_strategy = exe_strategy
+        """
+        execution_strategy = paddle.fluid.ExecutionStrategy()
+        fields = self.strategy.execution_strategy.DESCRIPTOR.fields
+        for f in fields:
+            setattr(execution_strategy, f.name,
+                    getattr(self.strategy.execution_strategy, f.name))
+        return execution_strategy
+
+    @execution_strategy.setter
+    def execution_strategy(self, strategy):
+        fields = self.strategy.execution_strategy.DESCRIPTOR.fields
+        for f in fields:
+            setattr(self.strategy.execution_strategy, f.name,
+                    getattr(strategy, f.name))
+
+    @property
+    def build_strategy(self):
+        """
+        Configure BuildStrategy for DistributedStrategy
+        Note that a BuildStrategy property takes effect here only if it is a
+        non-distributed option; the distributed arguments have been moved out
+        of BuildStrategy into DistributedStrategy itself.
+
+        Examples:
+          .. code-block:: python
+
+            build_strategy = paddle.fluid.BuildStrategy()
+            build_strategy.enable_sequential_execution = True
+            build_strategy.fuse_elewise_add_act_ops = True
+            build_strategy.fuse_bn_act_ops = True
+            build_strategy.enable_auto_fusion = True
+            build_strategy.fuse_relu_depthwise_conv = True
+            build_strategy.fuse_broadcast_ops = True
+            build_strategy.fuse_all_optimizer_ops = True
+            build_strategy.enable_inplace = True
+
+            strategy = paddle.fleet.DistributedStrategy()
+            strategy.build_strategy = build_strategy
+        """
+
+        build_strategy = paddle.fluid.BuildStrategy()
+        fields = self.strategy.build_strategy.DESCRIPTOR.fields
+        for f in fields:
+            setattr(build_strategy, f.name,
+                    getattr(self.strategy.build_strategy, f.name))
+        return build_strategy
+
+    @build_strategy.setter
+    def build_strategy(self, strategy):
+        fields = self.strategy.build_strategy.DESCRIPTOR.fields
+        for f in fields:
+            if f.label == 1 or f.label == 2:  # optional and required field
+                setattr(self.strategy.build_strategy, f.name,
+                        getattr(strategy, f.name))
+            elif f.label == 3:  # repeated field
+                getattr(self.strategy.build_strategy,
+                        f.name).extend(getattr(strategy, f.name))
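
The two pass-through properties above copy fields by descriptor, so a strategy
configured for local training carries over unchanged; a minimal sketch (values
illustrative):

    import paddle.fluid as fluid
    import paddle.fleet as fleet

    exe_strategy = fluid.ExecutionStrategy()
    exe_strategy.num_threads = 4

    build_strategy = fluid.BuildStrategy()
    build_strategy.fuse_elewise_add_act_ops = True

    strategy = fleet.DistributedStrategy()
    strategy.execution_strategy = exe_strategy  # copied field by field
    strategy.build_strategy = build_strategy    # only non-distributed fields live here
    print(strategy.execution_strategy.num_threads)  # 4
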
+    @property
+    def async_update(self):
+        """
+        Indicating whether we are using asynchronous stochastic gradient descent updates
+        for training. This property is valid only for parameter server training, which is
+        implied by setting an appropriate RoleMaker
+        Default value: True
+
+        Examples:
+          .. code-block:: python
+
+            import paddle.fleet as fleet
+            role_maker = fleet.PaddleCloudRoleMaker()
+            fleet.init(role_maker)
+
+            strategy = fleet.DistributedStrategy()
+            strategy.async_update = True  # by default this is True
+
+            # code block for defining loss and local optimizer
+            # sgd = fleet.distributed_optimizer(optimizer, strategy)
+        """
+        return self.strategy.async
+
+    @async_update.setter
+    def async_update(self, flag):
         if isinstance(flag, bool):
-            self.strategy.amp = flag
-        else:
-            print("WARNING: amp should have value of bool type")
-
-    @property
-    def amp_loss_scaling(self):
-        return self.strategy.amp_loss_scaling
-
-    @amp_loss_scaling.setter
-    def amp_loss_scaling(self, value):
-        if isinstance(value, int):
-            self.strategy.amp_loss_scaling = value
+            self.strategy.async = flag
         else:
-            print("WARNING: amp_loss_scaling should have value of int type")
+            print("WARNING: async_update should have value of bool type")

     @property
-    def recompute(self):
-        return self.strategy.recompute
+    def async_update_configs(self):
+        """
+        Set async update configurations. In general, asynchronous parameter server
+        training has several configurable settings that can be passed in through a dict.

-    @recompute.setter
-    def recompute(self, flag):
-        if isinstance(flag, bool):
-            self.strategy.recompute = flag
-        else:
-            print("WARNING: recompute should have value of bool type")
+        **Notes**:
+            **Detailed arguments for async_update_configs**
+            **k_steps**: number of local optimization updates before communication
+            **max_merge_var_num**: maximum number of merged gradients before communication
+            **send_queue_size**: the buffer size for worker communication
+            **independent_recv_thread**: whether to use an independent recv thread for communication
+            **thread_pool_size**: size of the thread pool
+            **send_wait_times**: waiting time for sending gradients
+            **runtime_split_send_recv**: whether to split Tensors for send and recv at runtime

-    @property
-    def recompute_checkpoints(self):
-        return self.strategy.recompute_checkpoints
-
-    @recompute_checkpoints.setter
-    def recompute_checkpoints(self, checkpoints):
-        if isinstance(checkpoints, list):
-            str_list = True
-            var_list = True
-            for item in checkpoints:
-                if not isinstance(item, str):
-                    str_list = False
-                if not isinstance(item, Variable):
-                    var_list = False
-
-            assert (str_list and var_list) == False
-            if str_list:
-                self.strategy.ClearField("recompute_checkpoints")
-                self.strategy.recompute_checkpoints.extend(checkpoints)
-            elif var_list:
-                names = [x.name for x in checkpoints]
-                self.strategy.ClearField("recompute_checkpoints")
-                self.strategy.recompute_checkpoints.extend(names)
-            else:
-                print(
-                    "WARNING: recompute_checkpoints should have value of list[Variable] or list[name] type"
-                )
-        else:
-            print(
-                "WARNING: recompute_checkpoints should have value of list[Variable] or list[name] type"
-            )
-
-    @property
-    def pipeline(self):
-        return self.strategy.pipeline
-
-    @pipeline.setter
-    def pipeline(self, flag):
-        if isinstance(flag, bool):
-            self.strategy.pipeline = flag
-        else:
-            print("WARNING: pipeline should have value of bool type")
-
-    @property
-    def pipeline_micro_batch(self):
-        return self.strategy.pipeline_micro_batch
-
-    @pipeline_micro_batch.setter
-    def pipeline_micro_batch(self, value):
-        if isinstance(value, int):
-            self.strategy.pipeline_micro_batch = value
-        else:
-            print("WARNING: pipeline micro batch should have value of int type")
-
-    @property
-    def localsgd(self):
-        return self.strategy.localsgd
-
-    @localsgd.setter
-    def localsgd(self, flag):
-        if isinstance(flag, bool):
-            self.strategy.localsgd = flag
-        else:
-            print("WARNING: localsgd should have value of bool type")
-
-    @property
-    def localsgd_k_step(self):
-        return self.strategy.localsgd_k_step
-
-    @localsgd_k_step.setter
-    def localsgd_k_step(self, value):
-        if isinstance(value, int):
-            self.strategy.localsgd_k_step = value
-        else:
-            print("WARNING: localsgd_k_step should have value of int type")
+        Examples:
+          .. code-block:: python

-    @property
-    def dgc(self):
-        return self.strategy.dgc
+            import paddle.fleet as fleet
+            role_maker = fleet.PaddleCloudRoleMaker()
+            fleet.init(role_maker)

-    @dgc.setter
-    def dgc(self, flag):
-        if isinstance(flag, bool):
-            self.strategy.dgc = flag
-        else:
-            print("WARNING: dgc should have value of bool type")
+            strategy = fleet.DistributedStrategy()
+            strategy.async_update = True  # by default this is True
+            configs = {"k_steps": 10000, "send_queue_size": 32}
+            strategy.async_update_configs = configs

-    @property
-    def hierachical_allreduce(self):
-        return self.strategy.hierachical_allreduce
+            # code block for defining loss and local optimizer
+            # sgd = fleet.distributed_optimizer(optimizer, strategy)
+        """
+        return get_msg_dict(self.strategy.async_configs)

-    @hierachical_allreduce.setter
-    def hierachical_allreduce(self, flag):
-        if isinstance(flag, bool):
-            self.strategy.hierachical_allreduce = flag
-        else:
-            print(
-                "WARNING: hierachical_allreduce should have value of bool type")
+    @async_update_configs.setter
+    def async_update_configs(self, configs):
+        check_configs_key(self.strategy.async_configs, configs, "async_configs")
+        assign_configs_value(self.strategy.async_configs, configs)
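
Because check_configs_key validates keys against AsyncConfig's descriptor, a
misspelled key fails fast rather than being silently dropped; a quick sketch of
that behavior (assertion text comes from the helper at the top of this file):

    import paddle.fleet as fleet

    strategy = fleet.DistributedStrategy()
    try:
        strategy.async_update_configs = {"k_step": 100}  # wrong: the field is "k_steps"
    except AssertionError as e:
        print(e)  # key:k_step not in async_configs
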

     @property
-    def hierachical_allreduce_inter_ranks(self):
-        return self.strategy.hierachical_allreduce_inter_ranks
-
-    @hierachical_allreduce_inter_ranks.setter
-    def hierachical_allreduce_inter_ranks(self, flag):
-        if isinstance(flag, bool):
-            self.strategy.hierachical_allreduce_inter_ranks = flag
-        else:
-            print(
-                "WARNING: hierachical_allreduce_inter_ranks should have value of bool type"
-            )
+    def amp(self):
+        """
+        Indicating whether we are using automatic mixed precision training
+        Default Value: False

-    @property
-    def nccl_comm_num(self):
-        return self.strategy.nccl_comm_num
+        Examples:
+          ..
code-block:: python - @nccl_comm_num.setter - def nccl_comm_num(self, value): - if isinstance(value, int): - self.strategy.nccl_comm_num = value - else: - print("WARNING: nccl_comm_num should have value of int type") + import paddle.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.amp = True # by default this is false - @property - def gradient_merge(self): - return self.strategy.gradient_merge + """ + return self.strategy.amp - @gradient_merge.setter - def gradient_merge(self, flag): + @amp.setter + def amp(self, flag): if isinstance(flag, bool): - self.strategy.gradient_merge = flag + self.strategy.amp = flag else: - print("WARNING: gradient_merge should have value of bool type") + print("WARNING: amp should have value of bool type") @property - def gradient_merge_k_step(self): - return self.strategy.gradient_merge_k_step + def amp_configs(self): + return get_msg_dict(self.strategy.amp_configs) - @gradient_merge_k_step.setter - def gradient_merge_k_step(self, value): - if isinstance(value, int): - self.strategy.gradient_merge_k_step = value - else: - print( - "WARNING: gradient_merge_k_step should have value of int type") + @amp_configs.setter + def amp_configs(self, configs): + check_configs_key(self.strategy.amp_configs, configs, "amp_configs") + assign_configs_value(self.strategy.amp_configs, configs) @property - def sequential_execution(self): - return self.strategy.sequential_execution - - @sequential_execution.setter - def sequential_execution(self, flag): - if isinstance(flag, bool): - self.strategy.sequential_execution = flag - else: - print( - "WARNING: sequential_execution should have value of bool type") + def recompute(self): + """ + Indicating whether we are using forward recomputation for memory optimization + Default value: False + + Examples: + .. 
code-block:: python + + import paddle.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.recompute = True + # suppose x and y are names of checkpoint tensors for recomputation + strategy.recompute_configs = {"checkpoints": ["x", "y"]} + """ + return self.strategy.recompute @property def sync_nccl_allreduce(self): @@ -267,99 +327,44 @@ class DistributedStrategy(object): if isinstance(flag, bool): self.strategy.sync_nccl_allreduce = flag else: - print("WARNING: sync_nccl_allreduce should have avlue of bool type") + print("WARNING: sync_nccl_allreduce should have value of bool type") @property - def lars(self): - return self.strategy.lars + def use_hierarchical_allreduce(self): + return self.strategy.use_hierarchical_allreduce - @lars.setter - def lars(self, flag): + @use_hierarchical_allreduce.setter + def use_hierarchical_allreduce(self, flag): if isinstance(flag, bool): - self.strategy.lars = flag - else: - print("WARNING: lars should have value of bool type") - - @property - def lamb(self): - return self.strategy.lamb - - @lamb.setter - def lamb(self, flag): - if isinstance(flag, bool): - self.strategy.lamb = flag - else: - print("WARNING: lamb should have value of bool type") - - @property - def fuse_elewise_add_act_ops(self): - return self.strategy.fuse_elewise_add_act_ops - - @fuse_elewise_add_act_ops.setter - def fuse_elewise_add_act_ops(self, flag): - if isinstance(flag, bool): - self.strategy.fuse_elewise_add_act_ops = flag + self.strategy.use_hierarchical_allreduce = flag else: print( - "WARNING: fuse_elewise_add_act_ops should have value of bool type" + "WARNING: use_hierarchical_allreduce should have value of bool type" ) @property - def fuse_bn_act_ops(self): - return self.strategy.fuse_bn_act_ops - - @fuse_bn_act_ops.setter - def fuse_bn_act_ops(self, flag): - if isinstance(flag, bool): - self.strategy.fuse_bn_act_ops = flag - else: - print("WARNING: fuse_bn_act_ops should have value of bool type") - - @property - def enable_auto_fusion(self): - return self.strategy.enable_auto_fusion - - @enable_auto_fusion.setter - def enable_auto_fusion(self, flag): - if isinstance(flag, bool): - self.strategy.enable_auto_fusion = flag - else: - print("WARNING: enable_auto_fusion should have value of bool type") - - @property - def fuse_relu_depthwise_conv(self): - return self.strategy.fuse_relu_depthwise_conv + def hierarchical_allreduce_inter_nranks(self): + return self.strategy.hierarchical_allreduce_inter_nranks - @fuse_relu_depthwise_conv.setter - def fuse_relu_depthwise_conv(self, flag): - if isinstance(flag, bool): - self.strategy.fuse_relu_depthwise_conv = flag + @hierarchical_allreduce_inter_nranks.setter + def hierarchical_allreduce_inter_nranks(self, value): + if isinstance(value, int): + self.strategy.hierarchical_allreduce_inter_nranks = value else: print( - "WARNING: fuse_relu_depthwise_conv should have value of bool type" + "WARNING: hierarchical_allreduce_inter_nranks should have value of int type" ) @property - def fuse_broadcast_ops(self): - return self.strategy.fuse_broadcast_ops - - @fuse_broadcast_ops.setter - def fuse_broadcast_ops(self, flag): - if isinstance(flag, bool): - self.strategy.fuse_broadcast_ops = flag - else: - print("WARNING: fuse_broadcast_ops should have value of bool type") - - @property - def enable_inplace(self): - return self.strategy.enable_inplace + def sync_batch_norm(self): + return self.strategy.sync_batch_norm - @enable_inplace.setter - def enable_inplace(self, flag): + @sync_batch_norm.setter + def sync_batch_norm(self, flag): 
         if isinstance(flag, bool):
-            self.strategy.enable_inplace = flag
+            self.strategy.sync_batch_norm = flag
         else:
-            print("WARNING: enable_inplace should have value of bool type")
+            print("WARNING: sync_batch_norm should have value of bool type")

     @property
     def fuse_all_reduce_ops(self):
@@ -373,177 +378,188 @@ class DistributedStrategy(object):
             print("WARNING: fuse_all_reduce_ops should have value of bool type")

     @property
-    def num_iteration_per_drop_scope(self):
-        return self.strategy.num_iteration_per_drop_scope
-
-    @num_iteration_per_drop_scope.setter
-    def num_iteration_per_drop_scope(self, flag):
-        if isinstance(flag, int):
-            self.strategy.num_iteration_per_drop_scope = flag
-        else:
-            print(
-                "WARNING: num_iteration_per_drop_scope should have value of int type"
-            )
-
-    @property
-    def num_iteration_per_run(self):
-        return self.strategy.num_iteration_per_run
+    def nccl_comm_num(self):
+        return self.strategy.nccl_comm_num

-    @num_iteration_per_run.setter
-    def num_iteration_per_run(self, value):
+    @nccl_comm_num.setter
+    def nccl_comm_num(self, value):
         if isinstance(value, int):
-            self.strategy.num_iteration_per_run = value
+            self.strategy.nccl_comm_num = value
         else:
-            print(
-                "WARNING: num_iteration_per_run should have value of int type")
-
-    @property
-    def sync_batch_norm(self):
-        return self.strategy.sync_batch_norm
+            print("WARNING: nccl_comm_num should have value of int type")

-    @sync_batch_norm.setter
-    def sync_batch_norm(self, flag):
+    @recompute.setter
+    def recompute(self, flag):
         if isinstance(flag, bool):
-            self.strategy.sync_batch_norm = flag
+            self.strategy.recompute = flag
         else:
-            print("WARNING: sync_batch_norm should have value of bool type")
+            print("WARNING: recompute should have value of bool type")

     @property
-    def fuse_all_optimizer_ops(self):
-        return self.strategy.fuse_all_optimizer_ops
+    def recompute_configs(self):
+        """
+        Set recompute configurations. In general, the current implementation of
+        the recompute strategy requires manually assigned checkpoints

-    @fuse_all_optimizer_ops.setter
-    def fuse_all_optimizer_ops(self, flag):
-        if isinstance(flag, bool):
-            self.strategy.fuse_all_optimizer_ops = flag
-        else:
-            print(
-                "WARNING: fuse_all_optimizer_ops should have value of bool type")
+        Examples:
+          .. code-block:: python
+
+            import paddle.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.recompute = True
+            strategy.recompute_configs = {"checkpoints": ["x", "y"]}
+
+        """
+        return get_msg_dict(self.strategy.recompute_configs)
+
+    @recompute_configs.setter
+    def recompute_configs(self, configs):
+        check_configs_key(self.strategy.recompute_configs, configs,
+                          "recompute_configs")
+        assign_configs_value(self.strategy.recompute_configs, configs)
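
A slightly fuller sketch of how recomputation is wired up end to end (optimizer
plumbing as referenced in the docstrings above; tensor names hypothetical):

    import paddle.fluid as fluid
    import paddle.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.recompute = True
    # checkpoints name tensors of the user model, e.g. selected fc outputs
    strategy.recompute_configs = {"checkpoints": ["fc_0.tmp_1", "fc_1.tmp_1"]}

    optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    # optimizer = fleet.distributed_optimizer(optimizer, strategy)
    # optimizer.minimize(loss)
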

     @property
-    def sync(self):
-        return self.strategy.sync
+    def pipeline(self):
+        """
+        Indicating whether we are using pipeline parallelism for distributed training.
+        The current implementation mainly focuses on pipeline parallelism within a single
+        GPU machine, combined with data parallelism across GPU machines. The pipeline
+        stages are indicated through device_guard information in the user-defined program.
+
+        Examples:
+          .. code-block:: python
+
+            import paddle.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.pipeline = True
+
+        """
+        return self.strategy.pipeline

-    @sync.setter
-    def sync(self, flag):
+    @pipeline.setter
+    def pipeline(self, flag):
         if isinstance(flag, bool):
-            self.strategy.sync = flag
+            self.strategy.pipeline = flag
         else:
-            print("WARNING: sync should have value of bool type")
+            print("WARNING: pipeline should have value of bool type")

     @property
-    def async_k_step(self):
-        return self.strategy.async_k_step
+    def pipeline_configs(self):
+        """
+        Set pipeline parallelism configurations. In pipeline parallelism,
+        different parts of the neural network run on different GPUs. There is a
+        Tensor queue buffer between each pair of neighboring GPUs that is
+        responsible for synchronizing hidden Tensor results between them.
+        Pipeline parallelism consists of several producer-consumer style
+        hardware pairs, such as GPU-GPU, CPU-GPU and GPU-XPU. The best way to
+        speed up pipeline parallelism is to make the Tensors in the Tensor queue
+        smaller, so that each producer feeds its downstream consumer faster.

-    @async_k_step.setter
-    def async_k_step(self, value):
-        if isinstance(value, int):
-            self.strategy.async_k_step = value
-        else:
-            print("WARNING: async_k_step should have value of int type")
+        **Notes**:
+            **Detailed arguments for pipeline_configs**
+            **micro_batch**: the number of small batches in each user-defined batch

-    @property
-    def max_merge_var_num(self):
-        return self.strategy.max_merge_var_num
+        Examples:
+          .. code-block:: python
+
+            import paddle.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.pipeline = True
+            strategy.pipeline_configs = {"micro_batch": 12}

-    @max_merge_var_num.setter
-    def max_merge_var_num(self, value):
-        if isinstance(value, int):
-            self.strategy.max_merge_var_num = value
-        else:
-            print("WARNING: max_merge_var_num should have value of int type")
+        """

-    @property
-    def send_queue_size(self):
-        return self.strategy.send_queue_size
+        return get_msg_dict(self.strategy.pipeline_configs)

-    @send_queue_size.setter
-    def send_queue_size(self, value):
-        if isinstance(value, int):
-            self.strategy.send_queue_size = value
-        else:
-            print("WARNING: send_queue_size should have value of int type")
+    @pipeline_configs.setter
+    def pipeline_configs(self, configs):
+        check_configs_key(self.strategy.pipeline_configs, configs,
+                          "pipeline_configs")
+        assign_configs_value(self.strategy.pipeline_configs, configs)

     @property
-    def independent_recv_thread(self):
-        return self.strategy.independent_recv_thread
+    def localsgd(self):
+        return self.strategy.localsgd

-    @independent_recv_thread.setter
-    def independent_recv_thread(self, value):
-        if isinstance(value, bool):
-            self.strategy.independent_recv_thread = value
+    @localsgd.setter
+    def localsgd(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.localsgd = flag
         else:
-            print(
-                "WARNING: independent_recv_thread should have value of int type")
+            print("WARNING: localsgd should have value of bool type")

     @property
-    def min_send_grad_num_before_recv(self):
-        return self.strategy.min_send_grad_num_before_recv
+    def localsgd_configs(self):
+        return get_msg_dict(self.strategy.localsgd_configs)

-    @min_send_grad_num_before_recv.setter
-    def min_send_grad_num_before_recv(self, value):
-        if isinstance(value, int):
-            self.strategy.min_send_grad_num_before_recv = value
-        else:
-            print(
-                "WARNING: min_send_grad_num_before_recv should have value of int type"
-            )
+    @localsgd_configs.setter
+    def
localsgd_configs(self, configs): + check_configs_key(self.strategy.localsgd_configs, configs, + "localsgd_configs") + assign_configs_value(self.strategy.localsgd_configs, configs) @property - def thread_pool_size(self): - return self.strategy.thread_pool_size + def dgc(self): + return self.strategy.dgc - @thread_pool_size.setter - def thread_pool_size(self, value): - if isinstance(value, int): - self.strategy.thread_pool_size = value + @dgc.setter + def dgc(self, flag): + if isinstance(flag, bool): + self.strategy.dgc = flag else: - print("WARNING:thread_pool_size should have value of int type") + print("WARNING: dgc should have value of bool type") @property - def send_wait_times(self): - return self.strategy.send_wait_times + def dgc_configs(self): + return get_msg_dict(self.strategy.dgc_configs) - @send_wait_times.setter - def send_wait_times(self, value): - if isinstance(value, int): - self.strategy.send_wait_times = value - else: - print("WARNING: send_wait_times should have value of int type") + @dgc_configs.setter + def dgc_configs(self, configs): + check_configs_key(self.strategy.dgc_configs, configs, "dgc_configs") + assign_configs_value(self.strategy.dgc_configs, configs) @property - def runtime_split_send_recv(self): - return self.strategy.runtime_split_send_recv + def gradient_merge(self): + return self.strategy.gradient_merge - @runtime_split_send_recv.setter - def runtime_split_send_recv(self, flag): + @gradient_merge.setter + def gradient_merge(self, flag): if isinstance(flag, bool): - self.strategy.runtime_split_send_recv = flag + self.strategy.gradient_merge = flag else: - print("WARNING: runtime_split_send_recv should be bool type") + print("WARNING: gradient_merge should have value of bool type") + + @property + def gradient_merge_configs(self): + return get_msg_dict(self.strategy.gradient_merge_configs) + + @gradient_merge_configs.setter + def gradient_merge_configs(self, configs): + check_configs_key(self.strategy.gradient_merge_configs, configs, + "gradient_configs") + assign_configs_value(self.strategy.gradient_merge_configs, configs) @property - def use_thread_barrier(self): - return self.strategy.use_thread_barrier + def lars(self): + return self.strategy.lars - @use_thread_barrier.setter - def use_thread_barrier(self, flag): + @lars.setter + def lars(self, flag): if isinstance(flag, bool): - self.strategy.use_thread_barrier = flag + self.strategy.lars = flag else: - print("WARNING: use_thread_barrier should be bool type") + print("WARNING: lars should have value of bool type") @property - def enable_backward_optimizer_op_deps(self): - return self.strategy.enable_backward_optimizer_op_deps + def lamb(self): + return self.strategy.lamb - @enable_backward_optimizer_op_deps.setter - def enable_backward_optimizer_op_deps(self, flag): + @lamb.setter + def lamb(self, flag): if isinstance(flag, bool): - self.strategy.enable_backward_optimizer_op_deps = flag + self.strategy.lamb = flag else: - print( - "WARNING: enable_backward_optimizer_op_deps should be bool type") + print("WARNING: lamb should have value of bool type") @property def elastic(self): @@ -556,17 +572,6 @@ class DistributedStrategy(object): else: print("WARNING: elastic should have value of bool type") - @property - def num_threads(self): - return self.strategy.num_threads - - @num_threads.setter - def num_threads(self, value): - if isinstance(value, int): - self.strategy.num_threads = value - else: - print("WARNING: num_threads should have value of int type") - @property def auto(self): return 
self.strategy.auto diff --git a/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py index cc3d1cd2128..2991f80aa53 100644 --- a/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py +++ b/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py @@ -18,39 +18,6 @@ from .meta_optimizer_base import MetaOptimizerBase from ..base.private_helper_function import wait_server_ready -def get_build_strategy(dist_strategy): - build_strategy = paddle.BuildStrategy() - build_strategy.enable_sequential_execution = \ - dist_strategy.sequential_execution - build_strategy.remove_unnecessary_lock = True - build_strategy.fuse_elewise_add_act_ops = \ - dist_strategy.fuse_elewise_add_act_ops - build_strategy.fuse_bn_act_ops = \ - dist_strategy.fuse_bn_act_ops - build_strategy.enable_auto_fusion = \ - dist_strategy.enable_auto_fusion - build_strategy.fuse_relu_depthwise_conv = \ - dist_strategy.fuse_relu_depthwise_conv - build_strategy.fuse_broadcast_ops = \ - dist_strategy.fuse_broadcast_ops - build_strategy.sync_batch_norm = \ - dist_strategy.sync_batch_norm - return build_strategy - - -def get_execution_strategy(dist_strategy): - execution_strategy = paddle.ExecutionStrategy() - execution_strategy.num_threads = \ - dist_strategy.num_threads - execution_strategy.num_iteration_per_drop_scope = \ - dist_strategy.num_iteration_per_drop_scope - execution_strategy.num_iteration_per_run = \ - dist_strategy.num_iteration_per_run - execution_strategy.use_thread_barrier = \ - dist_strategy.use_thread_barrier - return execution_strategy - - class GraphExecutionOptimizer(MetaOptimizerBase): def __init__(self, optimizer): super(GraphExecutionOptimizer, self).__init__(optimizer) @@ -76,7 +43,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase): pass # should fix the variable - def _setup_nccl_op(self, startup_program, main_program): + def _setup_nccl_op(self, startup_program, main_program, build_strategy): trainer_endpoints = self.role_maker.get_trainer_endpoints() trainer_id = self.role_maker.worker_index() current_endpoint = self.role_maker.get_trainer_endpoints()[trainer_id] @@ -88,14 +55,14 @@ class GraphExecutionOptimizer(MetaOptimizerBase): wait_server_ready(other_trainer_endpoints) nccl_id_var = startup_program.global_block().create_var( name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW) - for i in range(1, self.user_defined_strategy.nccl_comm_num): + for i in range(1, build_strategy.nccl_comm_num): startup_program.global_block().create_var( name="NCCLID_{}".format(i), persistable=True, type=core.VarDesc.VarType.RAW) - if self.user_defined_strategy.hierachical_allreduce: - for i in range(0, self.user_defined_strategy.nccl_comm_num): + if build_strategy.use_hierarchical_allreduce: + for i in range(0, build_strategy.nccl_comm_num): startup_program.global_block().create_var( name="Hierarchical_inter_NCCLID_{}".format(i), persistable=True, @@ -112,48 +79,80 @@ class GraphExecutionOptimizer(MetaOptimizerBase): attrs={ "trainers": trainer_endpoints, "trainer_id": trainer_id, - "nccl_comm_num": self.user_defined_strategy.nccl_comm_num, + "nccl_comm_num": build_strategy.nccl_comm_num, "use_hierarchical_allreduce": - self.user_defined_strategy.hierachical_allreduce, + build_strategy.use_hierarchical_allreduce, "hierarchical_allreduce_inter_ranks": - self.user_defined_strategy.hierachical_allreduce_inter_ranks + build_strategy.hierarchical_allreduce_inter_nranks }) def _try_to_compile(self, startup_program, 
main_program, loss): - build_strategy = get_build_strategy(self.user_defined_strategy) - exe_strategy = get_execution_strategy(self.user_defined_strategy) + import copy + dist_strategy = self.user_defined_strategy + local_build_strategy = paddle.fluid.BuildStrategy() + local_build_strategy.enable_sequential_execution = \ + dist_strategy.build_strategy.enable_sequential_execution + local_build_strategy.fuse_elewise_add_act_ops = \ + dist_strategy.build_strategy.fuse_elewise_add_act_ops + local_build_strategy.fuse_bn_act_ops = \ + dist_strategy.build_strategy.fuse_bn_act_ops + local_build_strategy.enable_auto_fusion = \ + dist_strategy.build_strategy.enable_auto_fusion + local_build_strategy.fuse_relu_depthwise_conv = \ + dist_strategy.build_strategy.fuse_relu_depthwise_conv + local_build_strategy.fuse_broadcast_ops = \ + dist_strategy.build_strategy.fuse_broadcast_ops + local_build_strategy.fuse_all_optimizer_ops = \ + dist_strategy.build_strategy.fuse_all_optimizer_ops + local_build_strategy.enable_inplace = \ + dist_strategy.build_strategy.enable_inplace + local_build_strategy.use_hierarchical_allreduce = \ + dist_strategy.use_hierarchical_allreduce + local_build_strategy.hierarchical_allreduce_inter_nranks = \ + dist_strategy.hierarchical_allreduce_inter_nranks + local_build_strategy.sync_batch_norm = \ + dist_strategy.sync_batch_norm + local_build_strategy.fuse_all_reduce_ops = \ + dist_strategy.fuse_all_reduce_ops + local_build_strategy.nccl_comm_num = \ + dist_strategy.nccl_comm_num + + exe_strategy = self.user_defined_strategy.execution_strategy node_num = self.role_maker.worker_num() + if self.role_maker._is_collective: assert node_num >= 1, "nccl2 node_num must >= 1, now:{}" % node_num if node_num <= 1: # local mode - if self.user_defined_strategy.nccl_comm_num > 1: + if local_build_strategy.nccl_comm_num > 1: logging.warn("set nccl_comm_num=1 since you only have 1 node.") - self.user_defined_strategy.nccl_comm_num = 1 + local_build_strategy.nccl_comm_num = 1 - if self.user_defined_strategy.hierachical_allreduce: + if local_build_strategy.use_hierarchical_allreduce: logging.warn( "set hierachical_allreduce=False since you only have 1 node." 
) - self.user_defined_strategy.hierachical_allreduce = False + local_build_strategy.use_hierarchical_allreduce = False - sync_allreduce = self.user_defined_strategy.sync_nccl_allreduce + sync_allreduce = dist_strategy.sync_nccl_allreduce if sync_allreduce: - exe_strategy.num_threads = self.user_defined_strategy.nccl_comm_num + 1 - if self.user_defined_strategy.hierachical_allreduce: - exe_strategy.num_threads = 2 * self.user_defined_strategy.nccl_comm_num + 1 + paddle.fluid.framework.set_flags({ + "FLAGS_sync_nccl_allreduce": True + }) + exe_strategy.num_threads = local_build_strategy.nccl_comm_num + 1 + if local_build_strategy.use_hierarchical_allreduce: + exe_strategy.num_threads = 2 * local_build_strategy.nccl_comm_num + 1 if exe_strategy.num_threads > 4: logging.warn( "if you use hierachical_allreduce or " - "with multi nccl comm, please export FLAGS_sync_nccl_allreduce = 0" + "with multi nccl comm, please set distributed_strategy.sync_nccl_allreduce=False" ) - # TODO(guru4elephant): should be an independent optimizer - sync_batch_norm = self.user_defined_strategy.sync_batch_norm + sync_batch_norm = local_build_strategy.sync_batch_norm if sync_batch_norm: - self.user_defined_strategy.nccl_comm_num = 1 - self.user_defined_strategy.hierachical_allreduce = False + local_build_strategy.nccl_comm_num = 1 + local_build_strategy.use_hierarchical_allreduce = False exe_strategy.num_threads = 1 logging.warn( "use sync_batch_norm will hang when set num_threads > 1, so " @@ -161,19 +160,19 @@ class GraphExecutionOptimizer(MetaOptimizerBase): ) # TODO(guru4elephant): should be an independent optimizer - self._setup_nccl_op(startup_program, main_program) + self._setup_nccl_op(startup_program, main_program, local_build_strategy) - build_strategy.num_trainers = self.role_maker.worker_num() - build_strategy.trainer_id = self.role_maker.worker_index() - build_strategy.trainers_endpoints = self.role_maker.get_trainer_endpoints( + local_build_strategy.num_trainers = self.role_maker.worker_num() + local_build_strategy.trainer_id = self.role_maker.worker_index() + local_build_strategy.trainers_endpoints = self.role_maker.get_trainer_endpoints( ) - build_strategy.enable_backward_optimizer_op_deps = True + local_build_strategy.enable_backward_optimizer_op_deps = True self._compiled_program = compiler.CompiledProgram(main_program) self._compiled_program.with_data_parallel( loss_name=loss.name, - build_strategy=build_strategy, + build_strategy=local_build_strategy, exec_strategy=exe_strategy, share_vars_from=None) @@ -188,7 +187,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase): startup_program = paddle.default_startup_program() compiled_program = self._try_to_compile(startup_program, loss.block.program, loss) - loss.block.program.graph = compiled_program + loss.block.program._graph = compiled_program # just return self.optimizer_ops and self.param_grads return None, None diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 393ee0682d4..8e6aa43e1ad 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -3936,6 +3936,9 @@ class Program(object): # appending gradients times self._appending_grad_times = 0 + # compiled program, i.e. 
Graph + self._graph = None + def global_seed(self, seed=0): """ Set global seed for Program diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py index bac03176c8d..890d716ff0c 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py @@ -27,12 +27,18 @@ class TestStrategyConfig(unittest.TestCase): strategy.amp = "True" self.assertEqual(strategy.amp, False) - def test_amp_loss_scaling(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.amp_loss_scaling = 32768 - self.assertEqual(strategy.amp_loss_scaling, 32768) - strategy.amp_loss_scaling = 0.1 - self.assertEqual(strategy.amp_loss_scaling, 32768) + def test_amp_configs(self): + strategy = paddle.fleet.DistributedStrategy() + configs = { + "init_loss_scaling": 32768, + "decr_every_n_nan_or_inf": 2, + "incr_every_n_steps": 1000, + "incr_ratio": 2.0, + "use_dynamic_loss_scaling": True, + "decr_ratio": 0.5 + } + strategy.amp_configs = configs + self.assertEqual(strategy.amp_configs["init_loss_scaling"], 32768) def test_recompute(self): strategy = paddle.fleet.DistributedStrategy() @@ -43,21 +49,12 @@ class TestStrategyConfig(unittest.TestCase): strategy.recompute = "True" self.assertEqual(strategy.recompute, False) - def test_recompute_checkpoints(self): + def test_recompute_configs(self): strategy = paddle.fleet.DistributedStrategy() - strategy.recompute_checkpoints = ["var1", "var2", "var3"] - self.assertEqual(len(strategy.recompute_checkpoints), 3) - import paddle.fluid as fluid - program = fluid.Program() - cur_block = program.current_block() - var1 = cur_block.create_var(name="var4", shape=[1, 1], dtype="int32") - var2 = cur_block.create_var(name="var5", shape=[1, 1], dtype="int32") - var3 = cur_block.create_var(name="var6", shape=[1, 1], dtype="int32") - strategy.recompute_checkpoints = [var1, var2, var3] - self.assertEqual(len(strategy.recompute_checkpoints), 3) - self.assertEqual(strategy.recompute_checkpoints[0], "var4") - strategy.recompute_checkpoints = [var1, "var2", var3] - self.assertEqual(strategy.recompute_checkpoints[1], "var5") + configs = {"checkpoints": ["x", "y"]} + strategy.recompute_configs = configs + self.assertEqual(len(strategy.recompute_configs["checkpoints"]), 2) + print(strategy.recompute_configs) def test_pipeline(self): strategy = paddle.fleet.DistributedStrategy() @@ -68,12 +65,11 @@ class TestStrategyConfig(unittest.TestCase): strategy.pipeline = "True" self.assertEqual(strategy.pipeline, False) - def test_pipeline_micro_batch(self): + def test_pipeline_configs(self): strategy = paddle.fleet.DistributedStrategy() - strategy.pipeline_micro_batch = 1 - self.assertEqual(strategy.pipeline_micro_batch, 1) - strategy.pipeline_micro_batch = 0.1 - self.assertEqual(strategy.pipeline_micro_batch, 1) + configs = {"micro_batch": 4} + strategy.pipeline_configs = configs + self.assertEqual(strategy.pipeline_configs["micro_batch"], 4) def test_localsgd(self): strategy = paddle.fleet.DistributedStrategy() @@ -84,12 +80,11 @@ class TestStrategyConfig(unittest.TestCase): strategy.localsgd = "True" self.assertEqual(strategy.localsgd, False) - def test_localsgd_k_step(self): + def test_localsgd_configs(self): strategy = paddle.fleet.DistributedStrategy() - strategy.localsgd_k_step = 1 - self.assertEqual(strategy.localsgd_k_step, 1) - strategy.localsgd_k_step = "2" - self.assertEqual(strategy.localsgd_k_step, 1) + 
configs = {"k_steps": 4} + strategy.localsgd_configs = configs + self.assertEqual(strategy.localsgd_configs["k_steps"], 4) def test_dgc(self): strategy = paddle.fleet.DistributedStrategy() @@ -100,21 +95,14 @@ class TestStrategyConfig(unittest.TestCase): strategy.dgc = "True" self.assertEqual(strategy.dgc, False) - def test_hierachical_allreduce(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.hierachical_allreduce = True - self.assertEqual(strategy.hierachical_allreduce, True) - strategy.hierachical_allreduce = False - self.assertEqual(strategy.hierachical_allreduce, False) - strategy.hierachical_allreduce = "True" - self.assertEqual(strategy.hierachical_allreduce, False) - - def test_hierachical_allreduce_inter_ranks(self): + def test_sync_nccl_allreduce(self): strategy = paddle.fleet.DistributedStrategy() - strategy.hierachical_allreduce_inter_ranks = 1 - self.assertEqual(strategy.hierachical_allreduce_inter_ranks, 1) - strategy.hierachical_allreduce_inter_ranks = "2" - self.assertEqual(strategy.hierachical_allreduce_inter_ranks, 1) + strategy.sync_nccl_allreduce = True + self.assertEqual(strategy.sync_nccl_allreduce, True) + strategy.sync_nccl_allreduce = False + self.assertEqual(strategy.sync_nccl_allreduce, False) + strategy.sync_nccl_allreduce = "True" + self.assertEqual(strategy.sync_nccl_allreduce, False) def test_nccl_comm_num(self): strategy = paddle.fleet.DistributedStrategy() @@ -123,6 +111,40 @@ class TestStrategyConfig(unittest.TestCase): strategy.nccl_comm_num = "2" self.assertEqual(strategy.nccl_comm_num, 1) + def test_use_hierarchical_allreduce(self): + strategy = paddle.fleet.DistributedStrategy() + strategy.use_hierarchical_allreduce = True + self.assertEqual(strategy.use_hierarchical_allreduce, True) + strategy.use_hierarchical_allreduce = False + self.assertEqual(strategy.use_hierarchical_allreduce, False) + strategy.use_hierarchical_allreduce = "True" + self.assertEqual(strategy.use_hierarchical_allreduce, False) + + def test_hierarchical_allreduce_inter_nranks(self): + strategy = paddle.fleet.DistributedStrategy() + strategy.hierarchical_allreduce_inter_nranks = 8 + self.assertEqual(strategy.hierarchical_allreduce_inter_nranks, 8) + strategy.hierarchical_allreduce_inter_nranks = "4" + self.assertEqual(strategy.hierarchical_allreduce_inter_nranks, 8) + + def test_sync_batch_norm(self): + strategy = paddle.fleet.DistributedStrategy() + strategy.sync_batch_norm = True + self.assertEqual(strategy.sync_batch_norm, True) + strategy.sync_batch_norm = False + self.assertEqual(strategy.sync_batch_norm, False) + strategy.sync_batch_norm = "True" + self.assertEqual(strategy.sync_batch_norm, False) + + def test_fuse_all_reduce_ops(self): + strategy = paddle.fleet.DistributedStrategy() + strategy.fuse_all_reduce_ops = True + self.assertEqual(strategy.fuse_all_reduce_ops, True) + strategy.fuse_all_reduce_ops = False + self.assertEqual(strategy.fuse_all_reduce_ops, False) + strategy.fuse_all_reduce_ops = "True" + self.assertEqual(strategy.fuse_all_reduce_ops, False) + def test_gradient_merge(self): strategy = paddle.fleet.DistributedStrategy() strategy.gradient_merge = True @@ -132,21 +154,11 @@ class TestStrategyConfig(unittest.TestCase): strategy.gradient_merge = "True" self.assertEqual(strategy.gradient_merge, False) - def test_gradient_merge_k_step(self): + def test_gradient_merge_configs(self): strategy = paddle.fleet.DistributedStrategy() - strategy.gradient_merge_k_step = 1 - self.assertEqual(strategy.gradient_merge_k_step, 1) - 
strategy.gradient_merge_k_step = "2" - self.assertEqual(strategy.gradient_merge_k_step, 1) - - def test_sequential_execution(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.sequential_execution = True - self.assertEqual(strategy.sequential_execution, True) - strategy.sequential_execution = False - self.assertEqual(strategy.sequential_execution, False) - strategy.sequential_execution = "True" - self.assertEqual(strategy.sequential_execution, False) + configs = {"k_steps": 4} + strategy.gradient_merge_configs = configs + self.assertEqual(strategy.gradient_merge_configs["k_steps"], 4) def test_lars(self): strategy = paddle.fleet.DistributedStrategy() @@ -166,171 +178,20 @@ class TestStrategyConfig(unittest.TestCase): strategy.lamb = "True" self.assertEqual(strategy.lamb, False) - def test_fuse_elewise_add_act_ops(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.fuse_elewise_add_act_ops = True - self.assertEqual(strategy.fuse_elewise_add_act_ops, True) - strategy.fuse_elewise_add_act_ops = False - self.assertEqual(strategy.fuse_elewise_add_act_ops, False) - strategy.fuse_elewise_add_act_ops = "True" - self.assertEqual(strategy.fuse_elewise_add_act_ops, False) - - def test_fuse_bn_act_ops(self): + def test_async_update(self): strategy = paddle.fleet.DistributedStrategy() - strategy.fuse_bn_act_ops = True - self.assertEqual(strategy.fuse_bn_act_ops, True) - strategy.fuse_bn_act_ops = False - self.assertEqual(strategy.fuse_bn_act_ops, False) - strategy.fuse_bn_act_ops = "True" - self.assertEqual(strategy.fuse_bn_act_ops, False) - - def test_enable_auto_fusion(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.enable_auto_fusion = True - self.assertEqual(strategy.enable_auto_fusion, True) - strategy.enable_auto_fusion = False - self.assertEqual(strategy.enable_auto_fusion, False) - strategy.enable_auto_fusion = "True" - self.assertEqual(strategy.enable_auto_fusion, False) - - def test_fuse_relu_depthwise_conv(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.fuse_relu_depthwise_conv = True - self.assertEqual(strategy.fuse_relu_depthwise_conv, True) - strategy.fuse_relu_depthwise_conv = False - self.assertEqual(strategy.fuse_relu_depthwise_conv, False) - strategy.fuse_relu_depthwise_conv = "True" - self.assertEqual(strategy.fuse_relu_depthwise_conv, False) - - def test_enable_inplace(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.enable_inplace = True - self.assertEqual(strategy.enable_inplace, True) - strategy.enable_inplace = False - self.assertEqual(strategy.enable_inplace, False) - strategy.enable_inplace = "True" - self.assertEqual(strategy.enable_inplace, False) + strategy.async_update = True + self.assertEqual(strategy.async_update, True) + strategy.async_update = False + self.assertEqual(strategy.async_update, False) + strategy.async_update = "True" + self.assertEqual(strategy.async_update, False) - def test_fuse_all_reduce_ops(self): + def test_async_configs(self): strategy = paddle.fleet.DistributedStrategy() - strategy.fuse_all_reduce_ops = True - self.assertEqual(strategy.fuse_all_reduce_ops, True) - strategy.fuse_all_reduce_ops = False - self.assertEqual(strategy.fuse_all_reduce_ops, False) - strategy.fuse_all_reduce_ops = "True" - self.assertEqual(strategy.fuse_all_reduce_ops, False) - - def test_num_iteration_per_drop_scope(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.num_iteration_per_drop_scope = 1 - self.assertEqual(strategy.num_iteration_per_drop_scope, 1) - 
strategy.num_iteration_per_drop_scope = 0.1 - self.assertEqual(strategy.num_iteration_per_drop_scope, 1) - - def test_num_iteration_per_run(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.num_iteration_per_run = 1 - self.assertEqual(strategy.num_iteration_per_run, 1) - strategy.num_iteration_per_run = 0.1 - self.assertEqual(strategy.num_iteration_per_run, 1) - - def test_sync_batch_norm(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.sync_batch_norm = True - self.assertEqual(strategy.sync_batch_norm, True) - strategy.sync_batch_norm = False - self.assertEqual(strategy.sync_batch_norm, False) - strategy.sync_batch_norm = "True" - self.assertEqual(strategy.sync_batch_norm, False) - - def test_fuse_all_optimizer_ops(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.fuse_all_optimizer_ops = True - self.assertEqual(strategy.fuse_all_optimizer_ops, True) - strategy.fuse_all_optimizer_ops = False - self.assertEqual(strategy.fuse_all_optimizer_ops, False) - strategy.fuse_all_optimizer_ops = "True" - self.assertEqual(strategy.fuse_all_optimizer_ops, False) - - def test_sync(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.sync = True - self.assertEqual(strategy.sync, True) - strategy.sync = False - self.assertEqual(strategy.sync, False) - strategy.sync = "True" - self.assertEqual(strategy.sync, False) - - def test_async_k_step(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.async_k_step = 10000 - self.assertEqual(strategy.async_k_step, 10000) - strategy.async_k_step = 0.1 - self.assertEqual(strategy.async_k_step, 10000) - - def test_send_queue_size(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.send_queue_size = 10000 - self.assertEqual(strategy.send_queue_size, 10000) - strategy.send_queue_size = 0.1 - self.assertEqual(strategy.send_queue_size, 10000) - - def test_independent_recv_thread(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.independent_recv_thread = True - self.assertEqual(strategy.independent_recv_thread, True) - strategy.independent_recv_thread = False - self.assertEqual(strategy.independent_recv_thread, False) - strategy.independent_recv_thread = "True" - self.assertEqual(strategy.independent_recv_thread, False) - - def test_min_send_grad_num_before_recv(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.min_send_grad_num_before_recv = 10000 - self.assertEqual(strategy.min_send_grad_num_before_recv, 10000) - strategy.min_send_grad_num_before_recv = 0.1 - self.assertEqual(strategy.min_send_grad_num_before_recv, 10000) - - def test_thread_pool_size(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.thread_pool_size = 10000 - self.assertEqual(strategy.thread_pool_size, 10000) - strategy.thread_pool_size = 0.1 - self.assertEqual(strategy.thread_pool_size, 10000) - - def test_send_wait_times(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.send_wait_times = 10000 - self.assertEqual(strategy.send_wait_times, 10000) - strategy.send_wait_times = 0.1 - self.assertEqual(strategy.send_wait_times, 10000) - - def test_runtime_split_send_recv(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.runtime_split_send_recv = True - self.assertEqual(strategy.runtime_split_send_recv, True) - strategy.runtime_split_send_recv = False - self.assertEqual(strategy.runtime_split_send_recv, False) - strategy.runtime_split_send_recv = "True" - self.assertEqual(strategy.runtime_split_send_recv, False) - - def 
use_thread_barrier(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.thread_barrier = True - self.assertEqual(strategy.thread_barrier, True) - strategy.thread_barrier = False - self.assertEqual(strategy.thread_barrier, False) - strategy.thread_barrier = "True" - self.assertEqual(strategy.thread_barrier, False) - - def test_enable_backward_optimizer_op_deps(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.enable_backward_optimizer_op_deps = True - self.assertEqual(strategy.enable_backward_optimizer_op_deps, True) - strategy.enable_backward_optimizer_op_deps = False - self.assertEqual(strategy.enable_backward_optimizer_op_deps, False) - strategy.enable_backward_optimizer_op_deps = "True" - self.assertEqual(strategy.enable_backward_optimizer_op_deps, False) + configs = {"k_steps": 1000} + strategy.async_update_configs = configs + self.assertEqual(strategy.async_update_configs["k_steps"], 1000) def test_elastic(self): strategy = paddle.fleet.DistributedStrategy() @@ -350,39 +211,69 @@ class TestStrategyConfig(unittest.TestCase): strategy.auto = "True" self.assertEqual(strategy.auto, False) - def test_sync_nccl_allreduce(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.sync_nccl_allreduce = True - self.assertEqual(strategy.sync_nccl_allreduce, True) - strategy.sync_nccl_allreduce = False - self.assertEqual(strategy.sync_nccl_allreduce, False) - strategy.sync_nccl_allreduce = "True" - self.assertEqual(strategy.sync_nccl_allreduce, False) - - def test_fuse_broadcast_ops(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.fuse_broadcast_ops = True - self.assertEqual(strategy.fuse_broadcast_ops, True) - strategy.fuse_broadcast_ops = False - self.assertEqual(strategy.fuse_broadcast_ops, False) - strategy.fuse_broadcast_ops = "True" - self.assertEqual(strategy.fuse_broadcast_ops, False) - - def test_num_threads(self): - strategy = paddle.fleet.DistributedStrategy() - strategy.num_threads = 1 - self.assertEqual(strategy.num_threads, 1) - strategy.num_threads = 0.1 - self.assertEqual(strategy.num_threads, 1) - def test_strategy_prototxt(self): strategy = paddle.fleet.DistributedStrategy() - strategy.sync_nccl_allreduce = True + strategy.async_update = True + strategy.localsgd = True + strategy.dgc = True + localsgd_configs = {"k_steps": 5} + strategy.localsgd_configs = localsgd_configs + build_strategy = paddle.fluid.BuildStrategy() + build_strategy.enable_sequential_execution = True + build_strategy.nccl_comm_num = 10 + build_strategy.use_hierarchical_allreduce = True + build_strategy.hierarchical_allreduce_inter_nranks = 1 + build_strategy.fuse_elewise_add_act_ops = True + build_strategy.fuse_bn_act_ops = True + build_strategy.enable_auto_fusion = True + build_strategy.fuse_relu_depthwise_conv = True + build_strategy.fuse_broadcast_ops = True + build_strategy.fuse_all_optimizer_ops = True + build_strategy.sync_batch_norm = True + build_strategy.enable_inplace = True + build_strategy.fuse_all_reduce_ops = True + build_strategy.enable_backward_optimizer_op_deps = True + build_strategy.trainers_endpoints = ["1", "2"] + strategy.build_strategy = build_strategy + exe_strategy = paddle.fluid.ExecutionStrategy() + exe_strategy.num_threads = 10 + exe_strategy.num_iteration_per_drop_scope = 10 + exe_strategy.num_iteration_per_run = 10 + strategy.execution_strategy = exe_strategy strategy.save_to_prototxt("dist_strategy.prototxt") strategy2 = paddle.fleet.DistributedStrategy() strategy2.load_from_prototxt("dist_strategy.prototxt") - 
self.assertEqual(strategy.sync_nccl_allreduce, - strategy2.sync_nccl_allreduce) + self.assertEqual(strategy.dgc, strategy2.dgc) + + def test_build_strategy(self): + build_strategy = paddle.fluid.BuildStrategy() + build_strategy.enable_sequential_execution = True + build_strategy.nccl_comm_num = 10 + build_strategy.use_hierarchical_allreduce = True + build_strategy.hierarchical_allreduce_inter_nranks = 1 + build_strategy.fuse_elewise_add_act_ops = True + build_strategy.fuse_bn_act_ops = True + build_strategy.enable_auto_fusion = True + build_strategy.fuse_relu_depthwise_conv = True + build_strategy.fuse_broadcast_ops = True + build_strategy.fuse_all_optimizer_ops = True + build_strategy.sync_batch_norm = True + build_strategy.enable_inplace = True + build_strategy.fuse_all_reduce_ops = True + build_strategy.enable_backward_optimizer_op_deps = True + build_strategy.trainers_endpoints = ["1", "2"] + + strategy = paddle.fleet.DistributedStrategy() + strategy.build_strategy = build_strategy + + def test_execution_strategy(self): + exe_strategy = paddle.fluid.ExecutionStrategy() + exe_strategy.num_threads = 10 + exe_strategy.num_iteration_per_drop_scope = 10 + exe_strategy.num_iteration_per_run = 10 + + strategy = paddle.fleet.DistributedStrategy() + strategy.execution_strategy = exe_strategy if __name__ == '__main__': -- GitLab
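
Taken together, the reworked surface looks like this in user code; a minimal
end-to-end sketch (role maker and optimizer wiring as in the docstrings above,
all values illustrative):

    import paddle.fluid as fluid
    import paddle.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.amp = True
    strategy.amp_configs = {"init_loss_scaling": 32768.0}
    strategy.pipeline = True
    # micro_batch splits each user batch: with batch size 64 and micro_batch = 4,
    # every pipeline stage consumes 4 micro-batches of 16 samples each
    strategy.pipeline_configs = {"micro_batch": 4}

    build_strategy = fluid.BuildStrategy()
    build_strategy.fuse_elewise_add_act_ops = True
    strategy.build_strategy = build_strategy

    strategy.save_to_prototxt("dist_strategy.prototxt")

    restored = fleet.DistributedStrategy()
    restored.load_from_prototxt("dist_strategy.prototxt")
    assert restored.amp_configs["init_loss_scaling"] == 32768.0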