From 1e60a0c4b4f044036c8f9bd95482ec110ac8e8c6 Mon Sep 17 00:00:00 2001 From: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com> Date: Wed, 7 Apr 2021 18:24:16 +0800 Subject: [PATCH] [3D-parallelism] Hybrid Model Parallelism (#32074) --- .../framework/distributed_strategy.proto | 18 +- .../meta_optimizers/pipeline_optimizer.py | 5 + .../meta_optimizers/sharding/fp16_helper.py | 66 ++- .../sharding/gradient_clip_helper.py | 81 ++- .../sharding/offload_helper.py | 281 ++++++++++ .../fleet/meta_optimizers/sharding/prune.py | 4 + .../fleet/meta_optimizers/sharding/utils.py | 66 ++- .../meta_optimizers/sharding_optimizer.py | 517 ++++++++++++++---- python/paddle/fluid/backward.py | 15 + python/paddle/fluid/optimizer.py | 56 +- .../test_fleet_sharding_meta_optimizer.py | 66 ++- 11 files changed, 1023 insertions(+), 152 deletions(-) mode change 100644 => 100755 python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py create mode 100755 python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py mode change 100644 => 100755 python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 805ef1c3e91..6363eedc80a 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -29,14 +29,18 @@ message RecomputeConfig { } message ShardingConfig { - optional float segment_broadcast_MB = 1 [ default = 32.0 ]; - optional bool hybrid_dp = 2 [ default = false ]; - optional int32 sharding_degree = 3 [ default = 8 ]; - optional int32 mp_degree = 4 [ default = 1 ]; - optional string sharding_segment_strategy = 5 + optional string sharding_segment_strategy = 1 [ default = 'segment_broadcast_MB' ]; - repeated string segment_anchors = 6; - optional int32 gradient_merge_acc_step = 7 [ default = 1 ]; + optional float segment_broadcast_MB = 2 [ default = 32.0 ]; + repeated string segment_anchors = 3; + optional int32 sharding_degree = 4 [ default = 8 ]; + optional int32 mp_degree = 5 [ default = 1 ]; + optional int32 dp_degree = 6 [ default = 1 ]; + optional bool hybrid_dp = 7 [ default = false ]; + optional int32 gradient_merge_acc_step = 8 [ default = 1 ]; + optional bool optimize_offload = 9 [ default = false ]; + optional bool pp_allreduce_in_optimize = 10 [ default = false ]; + optional int32 pp_degree = 11 [ default = 1 ]; } message AMPConfig { diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py old mode 100644 new mode 100755 index 6cb7593b6bf..ae2daa9b9d8 --- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py @@ -45,11 +45,16 @@ class PipelineOptimizer(MetaOptimizerBase): 'accumulate_steps'] self.schedule_mode = user_defined_strategy.pipeline_configs[ 'schedule_mode'] + self.use_sharding = user_defined_strategy.sharding def _can_apply(self): if not self.role_maker._is_collective: return False + # FIXME revise for hybrid parallelism + if self.use_sharding: + return False + if self.user_defined_strategy.pipeline == True: return True return False diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py index cf399f66946..40ba7781566 100755 --- 
a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py @@ -81,7 +81,10 @@ class FP16Utils(object): if not FP16Utils.is_fp32_cast_op(block, op): continue output_name = op.desc.output_arg_names()[0] - param_name = output_name.strip("@GRAD") + # TODO (JZ-LIANG) revise this for uniform mixed parallelism + param_name = output_name.strip( + "@GRAD@MERGED" + ) if "@MERGED" in output_name else output_name.strip("@GRAD") if param_name not in shard.global_params: raise ValueError("Output 'X' of cast_op must be a grad of" "model param, but {} is not a grad".format( @@ -105,7 +108,11 @@ class FP16Utils(object): reversed_x = [] reversed_x_paramname = [] for input_name in op.desc.input('X'): - param_name = input_name.strip("@GRAD") + # TODO (JZ-LIANG) revise this for uniform mixed parallelism + if "@MERGED" in input_name: + param_name = input_name.strip("@GRAD@MERGED") + else: + param_name = input_name.strip("@GRAD") if param_name not in shard.global_params: raise ValueError( "Input 'X' of check_finite_and_unscale must" @@ -169,3 +176,58 @@ class FP16Utils(object): OP_ROLE_KEY: OpRole.Optimize }) block._sync_with_cpp() + + # TODO (JZ-LIANG) revise this for uniform mixed parallelism + @staticmethod + def sync_amp_check_nan_inf(block, ring_id): + update_loss_scaling_op_idx = -1 + + for idx, op in reversed(list(enumerate(block.ops))): + if op.type == "update_loss_scaling": + update_loss_scaling_op_idx = idx + inf_var_name = op.desc.input('FoundInfinite')[0] + op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD") + + # not use amp + if update_loss_scaling_op_idx == -1: + return + inf_var = block.var(inf_var_name) + inf_var_int32 = block.create_var( + name=inf_var_name + "@cast_int32", + shape=inf_var.shape, + dtype=core.VarDesc.VarType.INT32) + inf_var_global = block.create_var( + name=inf_var_name + "@GLOBAL_WORLD", + shape=inf_var.shape, + dtype=inf_var.dtype) + block._insert_op_without_sync( + update_loss_scaling_op_idx, + type='cast', + inputs={'X': inf_var}, + outputs={'Out': inf_var_int32}, + attrs={ + "in_dtype": inf_var.dtype, + "out_dtype": inf_var_int32.dtype, + OP_ROLE_KEY: OpRole.Optimize + }) + block._insert_op_without_sync( + update_loss_scaling_op_idx + 1, + type='c_allreduce_max', + inputs={'X': inf_var_int32}, + outputs={'Out': inf_var_int32}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize + }) + block._insert_op_without_sync( + update_loss_scaling_op_idx + 2, + type='cast', + inputs={'X': inf_var_int32}, + outputs={'Out': inf_var_global}, + attrs={ + "in_dtype": inf_var_int32.dtype, + "out_dtype": inf_var_global.dtype, + OP_ROLE_KEY: OpRole.Optimize + }) + block._sync_with_cpp() diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py index 5082bc33830..d5a012b147a 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py @@ -32,6 +32,7 @@ class GradientClipHelper(object): deperated_vars = set() deperate_op_idx = set() reversed_x_paramname = [] + global_norm_sum_op_idx = -1 for idx, op in enumerate(block.ops): if not self._is_gradient_clip_op(op): continue @@ -41,7 +42,11 @@ class GradientClipHelper(object): for input_name in op.desc.input_arg_names(): if input_name in deperated_vars: deperate_op = True - 
param_name = input_name.strip("@GRAD") + # TODO (JZ-LIANG) revise this for uniform mixed parallelism + if "@MERGED" in input_name: + param_name = input_name.strip("@GRAD@MERGED") + else: + param_name = input_name.strip("@GRAD") if shard.is_param(param_name) and \ not shard.has_param(param_name): deperate_op = True @@ -51,7 +56,8 @@ class GradientClipHelper(object): if deperate_op: deperate_op_idx.add(idx) for output_name in op.desc.output_arg_names(): - deperated_vars.add(output_name) + if output_name not in op.desc.input_arg_names(): + deperated_vars.add(output_name) if not deperated_vars: # got no gradient_clip op @@ -65,6 +71,7 @@ class GradientClipHelper(object): continue reversed_inputs = [] if op.type == "sum": + global_norm_sum_op_idx = idx for input_name in op.desc.input_arg_names(): if input_name not in deperated_vars: reversed_inputs.append(input_name) @@ -86,20 +93,20 @@ class GradientClipHelper(object): OP_ROLE_KEY: OpRole.Optimize, }) - # global norm should only be sum within each model parallelism word size when use global group - if pure_dp_degree > 1: - block._insert_op_without_sync( - idx + 2, - type='scale', - inputs={'X': sum_res}, - outputs={'Out': sum_res}, - attrs={ - 'scale': 1.0 / float(pure_dp_degree), - 'op_namescope': "/gradient_clip_model_parallelism", - 'bias': 0.0, - 'bias_after_scale': False, - OP_ROLE_KEY: OpRole.Optimize - }) + # global norm should only be sum within each model parallelism word size when use global group + if pure_dp_degree > 1: + block._insert_op_without_sync( + idx + 2, + type='scale', + inputs={'X': sum_res}, + outputs={'Out': sum_res}, + attrs={ + 'scale': 1.0 / float(pure_dp_degree), + 'op_namescope': "/gradient_clip_model_parallelism", + 'bias': 0.0, + 'bias_after_scale': False, + OP_ROLE_KEY: OpRole.Optimize + }) # the grad sum here should take the all and only param in the current shard to_check_param = set(reversed_x_paramname) @@ -115,3 +122,45 @@ class GradientClipHelper(object): block._remove_var(var_name, sync=False) block._sync_with_cpp() return + + # TODO (JZ-LIANG) revise this for uniform mixed parallelism + def sync_global_norm(self, block, ring_id, pure_dp_degree=1): + """ + prune gradient_clip related ops for params that not belong to cur shard + prune: square, reduce_sum, elementwise_mul + keep: sum, sqrt, elementwise_max, elementwise_div + """ + for idx, op in reversed(list(enumerate(block.ops))): + if not self._is_gradient_clip_op(op): + continue + + if op.type == "sum": + sum_res = op.desc.output_arg_names()[0] + block._insert_op_without_sync( + idx + 1, + type='c_allreduce_sum', + inputs={'X': sum_res}, + outputs={'Out': sum_res}, + attrs={ + 'ring_id': ring_id, + 'op_namescope': "/gradient_clip_model_parallelism", + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, + }) + + # global norm should only be sum within each model parallelism word size + if pure_dp_degree > 1: + block._insert_op_without_sync( + idx + 2, + type='scale', + inputs={'X': sum_res}, + outputs={'Out': sum_res}, + attrs={ + 'scale': 1.0 / float(pure_dp_degree), + 'op_namescope': "/gradient_clip_model_parallelism", + 'bias': 0.0, + 'bias_after_scale': False, + OP_ROLE_KEY: OpRole.Optimize + }) + + return diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py new file mode 100755 index 00000000000..76803818453 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py @@ -0,0 +1,281 @@ +# Copyright 
(c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole +from paddle.fluid import core, unique_name + + +class OffloadHelper(object): + cpu_place_type = 0 + cuda_place_type = 1 + cuda_pinned_place_type = 2 + + def __init__(self): + pass + "0: dst is on CPUPlace. " + "1: dst is on CUDAPlace. " + "2: dst is on CUDAPinnedPlace. " + + def _insert_cast_op(self, block, idx, src_name, dst_name): + src_var = block.var(src_name) + if not block.has_var(dst_name): + block.create_var( + name=dst_name, + shape=src_var.shape, + dtype=core.VarDesc.VarType.FP16, + persistable=True) + dst_var = block.var(dst_name) + assert dst_var.dtype == core.VarDesc.VarType.FP16 + block._insert_op_without_sync( + idx, + type='cast', + inputs={'X': src_var}, + outputs={'Out': dst_var}, + attrs={ + 'in_dtype': src_var.dtype, + 'out_dtype': dst_var.dtype, + OP_ROLE_KEY: OpRole.Optimize + }) + + def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type): + src_var = block.var(src_name) + dst_var = block.var(dst_name) + block._insert_op_without_sync( + idx, + type='memcpy', + inputs={'X': src_var}, + outputs={'Out': dst_var}, + attrs={ + 'dst_place_type': dst_place_type, + OP_ROLE_KEY: OpRole.Optimize, + }) + + def _insert_fetch_op(self, block, idx, src_name, dst_name): + self._insert_memcpy_op(block, idx, src_name, dst_name, + OffloadHelper.cuda_place_type) + + def _insert_offload_op(self, block, idx, src_name, dst_name): + self._insert_memcpy_op(block, idx, src_name, dst_name, + OffloadHelper.cuda_pinned_place_type) + + def _get_offload_var_name(self, name): + return unique_name.generate(name + '@offload') + + def _create_offload_var(self, var_name, offload_var_name, blocks): + for block in blocks: + var = block.var(var_name) + var.persistable = False + offload_var = block.create_var( + name=offload_var_name, + shape=var.shape, + dtype=var.dtype, + persistable=True) + + def offload_fp32param(self, block, startup_block): + """ + (p_fp16) = cast(p) + (p_fp16_recompute) = cast(p) + (pout,) = adam(p) + ===========================> + rename(p_fp16_recompute, p_fp16) + + (p,) = prefetch(p@offload) + (pout,) = adam(p) + (p_fp16) = cast(p) + (p@offload) = memcpy(p) + """ + param_to_idx = dict() + param_to_fp16 = dict() + # recompute_var which need rename to fp16_param + fp16_param_to_recompute = dict() + recompute_to_fp16 = dict() + + def remove_param(input_name): + param_to_idx.pop(input_name) + if input_name in param_to_fp16: + fp16_param = param_to_fp16.pop(input_name) + if fp16_param in fp16_param_to_recompute: + recompute = fp16_param_to_recompute.pop(fp16_param) + recompute_to_fp16.pop(recompute) + + # step1: record param + for idx, op in reversed(list(enumerate(block.ops))): + if op.type in ('adam', 'momentum', 'lars', 'lamb'): + param = op.desc.input("Param")[0] + param_to_idx[param] = idx + + # step2: remove param which can't offload + for idx, op in enumerate(block.ops): + if is_optimizer_op(op): + break 
+ for input_name in op.desc.input_arg_names(): + if input_name not in param_to_idx: + continue + + # param is real used by fp32 op + if op.type != 'cast': + remove_param(input_name) + continue + + # param is only used by cast op, + # which to cast fp32_param to fp16_param + output_name = op.output_arg_names[0] + if 'cast_fp16' not in output_name: + remove_param(input_name) + continue + + if 'subprog' not in output_name: + assert output_name == input_name + '.cast_fp16' + assert input_name not in param_to_fp16, \ + "There must be only one cast op from fp32 param to fp16 param." + param_to_fp16[input_name] = output_name + else: + # fp16-->recompute_var + assert input_name in param_to_fp16, \ + "param must first be cast to fp16" + fp16_param = param_to_fp16[input_name] + fp16_param_to_recompute[fp16_param] = output_name + recompute_to_fp16[output_name] = fp16_param + + param_name_to_offload_name = dict() + # step3: main_block add offload, cast op + # change recompute to fp16, remove cast(param) to fp16 + for idx, op in reversed(list(enumerate(block.ops))): + if op.type in ('adam', 'momentum', 'lars', 'lamb'): + param = op.desc.input("Param")[0] + if param not in param_to_idx: continue + # step3.1: create offload_var + offload_var_name = self._get_offload_var_name(param) + param_name_to_offload_name[param] = offload_var_name + self._create_offload_var(param, offload_var_name, + [block, startup_block]) + + # step3.2: insert cast op and offload op + self._insert_offload_op(block, idx + 1, param, offload_var_name) + + assert param in param_to_fp16 + fp16_param_name = param_to_fp16[param] + fp16_param_var = block.var(fp16_param_name) + fp16_param_var.persistable = True + self._insert_cast_op(block, idx + 1, param, + param_to_fp16[param]) + + # step3.3: insert fetch op + self._insert_fetch_op(block, idx, offload_var_name, param) + continue + + # step3.4: remove cast op + if op.type == 'cast': + input_name = op.desc.input_arg_names()[0] + if input_name in param_to_idx: + block._remove_op(idx, sync=False) + continue + + # step3.5: change recompute_param to fp16_param + for input_name in op.desc.input_arg_names(): + if input_name in recompute_to_fp16: + op._rename_input(input_name, recompute_to_fp16[input_name]) + for output_name in op.desc.output_arg_names(): + if output_name in recompute_to_fp16: + op._rename_output(output_name, + recompute_to_fp16[output_name]) + + # step4: remove recompute_param + for name in recompute_to_fp16.keys(): + block._remove_var(name, sync=False) + + # step5: startup_block add offload + visited_vars = set() + for idx, op in reversed(list(enumerate(startup_block.ops))): + for out_name in op.output_arg_names: + if out_name in visited_vars: + continue + + if out_name in param_name_to_offload_name: + var_name = out_name + offload_var_name = param_name_to_offload_name[var_name] + self._insert_offload_op(startup_block, idx + 1, var_name, + offload_var_name) + self._insert_cast_op(startup_block, idx + 1, var_name, + param_to_fp16[var_name]) + + visited_vars.add(out_name) + + block._sync_with_cpp() + startup_block._sync_with_cpp() + + def offload(self, block, startup_block): + """ + (m1, m2) = prefetch(m1@offload, m2@offload) + (m1out, m2out, pout) = adam(m1, m2, p) + (m1@offload, m2@offload) = memcpy(m1, m2) + """ + vars_name_to_offload_name = dict() + + # main_block add offload + for idx, op in reversed(list(enumerate(block.ops))): + if not is_optimizer_op(op): + break + + vars_name = [] + if op.type == "adam": + # {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} = + # 
adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']}) + vars_name.append(op.desc.input("Moment1")[0]) + vars_name.append(op.desc.input("Moment2")[0]) + elif op.type == 'momentum': + pass + elif op.type == 'lars': + pass + elif op.type == 'lamb': + pass + + # step1: create and init offload_var + for var_name in vars_name: + assert var_name not in vars_name_to_offload_name + + offload_var_name = self._get_offload_var_name(var_name) + vars_name_to_offload_name[var_name] = offload_var_name + + self._create_offload_var(var_name, offload_var_name, + [block, startup_block]) + + # step2: insert offload op + for var_name in vars_name: + offload_var_name = vars_name_to_offload_name[var_name] + self._insert_offload_op(block, idx + 1, var_name, + offload_var_name) + + # step3: insert fetch op + for var_name in vars_name: + offload_var_name = vars_name_to_offload_name[var_name] + self._insert_fetch_op(block, idx, offload_var_name, var_name) + + # startup_block add offload + visited_vars = set() + for idx, op in reversed(list(enumerate(startup_block.ops))): + for out_name in op.output_arg_names: + if out_name in visited_vars: + continue + + if out_name in vars_name_to_offload_name: + var_name = out_name + offload_var_name = vars_name_to_offload_name[var_name] + # insert offload op after var is generated + self._insert_offload_op(startup_block, idx + 1, var_name, + offload_var_name) + visited_vars.add(out_name) + + block._sync_with_cpp() + startup_block._sync_with_cpp() diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py old mode 100644 new mode 100755 index 70753b59ccc..5a43367cf1a --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py @@ -126,6 +126,10 @@ class ProgramDeps(object): def should_remove_op(self, op_idx): op = self._block.ops[op_idx] + # TODO (JZ-LIANG) revise this for uniform mixed parallelism + # remove check_finite_and_unscale op if its input 'X' is empty + if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0: + return True for output_name in op.desc.output_arg_names(): if output_name not in self._should_removed_var: return False diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index 8b111026bdb..f4ceb2d287a 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -274,6 +274,10 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars): """ insert sync_comm_op for vars """ + # NOTE (JZ-LIANG) to be check, may result undefined case + if len(comm_dep_vars) == 0: + return 0 + op_role = get_valid_op_role(block, insert_idx) block._insert_op_without_sync( insert_idx, @@ -324,27 +328,45 @@ def insert_cast_ops(block, insert_idx, cast_ops): return -def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): +def insert_allreduce_ops(block, + insert_idx, + ring_id, + allreduce_vars, + op_role=OpRole.Backward, + use_calc_stream=False): """ _add_allreduce_ops """ + if len(allreduce_vars) == 0: + return + for var in allreduce_vars: block._insert_op_without_sync( insert_idx, type='c_allreduce_sum', inputs={'X': var}, outputs={'Out': var}, - attrs={'ring_id': ring_id, - OP_ROLE_KEY: OpRole.Backward}) + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': use_calc_stream, + OP_ROLE_KEY: 
op_role + }) return -def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard): +def insert_reduce_ops(block, + insert_idx, + ring_id, + reduce_vars, + shard, + op_role=OpRole.Backward, + use_calc_stream=False): """ _add_allreduce_ops """ for var in reduce_vars: + root_id = get_grad_device(var, shard) assert root_id >= 0, "root id should be a positive int".format(var) block._insert_op_without_sync( @@ -355,12 +377,40 @@ def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard): attrs={ 'ring_id': ring_id, 'root_id': root_id, - OP_ROLE_KEY: OpRole.Backward + 'use_calc_stream': use_calc_stream, + OP_ROLE_KEY: op_role }) - return +def get_grad_device(grad_name, shard): + assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format( + grad_name) + base_name = None + # mind the traversal order + possible_suffixes = [ + '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD' + ] + for suffix in possible_suffixes: + if suffix in grad_name: + base_name = re.sub(suffix, '', grad_name) + break + + assert base_name in shard.global_param2device, "[{}] should be a param variable.".format( + base_name) + + return shard.global_param2device[base_name] + + +def get_first_check_finite_and_unscale_op_idx(block): + + for idx, op in enumerate(block.ops): + if op.type == "check_finite_and_unscale": + return idx + + raise ValueError("check_finite_and_unscale does not exist in block") + + def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): """ _add_broadcast_ops @@ -420,6 +470,7 @@ def insert_scale_loss_grad_ops(block, scale=1.0): outputs={'Out': loss_grad_var}, attrs={'scale': scale, OP_ROLE_KEY: OpRole.Backward}) + break def comm_analyse(main_program): @@ -502,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None): and part of persistable vars are duplicated and exist in all the ranks with different values. This function handles the model saving for sharding training. 
""" + # TODO (JZ-LIANG) revise this for uniform mixed parallelism + if main_program._pipeline_opt: + main_program = main_program._pipeline_opt['section_program']['program'] def is_opt_vars(var): # NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index cf3f75740ee..a83ae226a9d 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -16,16 +16,16 @@ import paddle from paddle.fluid import unique_name, core import paddle.fluid as fluid from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper -from paddle.distributed.fleet.meta_optimizers.common import is_backward_op +from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper +from .sharding.offload_helper import OffloadHelper from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps from paddle.distributed.fleet.meta_optimizers.sharding.utils import * from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard - from paddle.fluid import layers import logging @@ -38,6 +38,8 @@ __all__ = ["ShardingOptimizer"] class ShardingOptimizer(MetaOptimizerBase): + """Sharding Optimizer.""" + def __init__(self, optimizer): super(ShardingOptimizer, self).__init__(optimizer) self.inner_opt = optimizer @@ -46,7 +48,8 @@ class ShardingOptimizer(MetaOptimizerBase): "AMPOptimizer", "LarsOptimizer", "LambOptimizer", - "ModelParallelOptimizer", + # "ModelParallelOptimizer", + # "PipelineOptimizer", ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] self._main_program = None @@ -88,26 +91,6 @@ class ShardingOptimizer(MetaOptimizerBase): self._nrings_sharding = 1 self._nrings_dp = 1 - # parallelism - self.sharding_degree = int(self.user_defined_strategy.sharding_configs[ - "sharding_degree"]) - assert self.sharding_degree > 1, "sharding degree must be larger than zero" - self.mp_degree = int(self.user_defined_strategy.sharding_configs[ - "mp_degree"]) - self.hybrid_dp = self.user_defined_strategy.sharding_configs[ - "hybrid_dp"] - - self.pp_degree = 1 - - # dp here is the pure dp as the outest parallelism - self.dp_degree = int(self.role_maker._worker_num() // self.mp_degree // - self.sharding_degree) - assert self.role_maker._worker_num( - ) == self.dp_degree * self.mp_degree * self.sharding_degree * self.pp_degree - if self.hybrid_dp: - assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format( - self.dp_degree) - # segment self._sharding_segment_strategy = str( self.user_defined_strategy.sharding_configs[ @@ -128,55 +111,231 @@ class ShardingOptimizer(MetaOptimizerBase): "the sharding segment strategy [{}] is not implemented".format( str(self._sharding_segment_strategy))) + # parallelism + self.sharding_degree = 
int(self.user_defined_strategy.sharding_configs[ + "sharding_degree"]) + assert self.sharding_degree > 0, "sharding degree must be larger than zero" + self.mp_degree = int(self.user_defined_strategy.sharding_configs[ + "mp_degree"]) + # pipeline setting + # TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline + self.pp_degree = int(self.user_defined_strategy.sharding_configs[ + "pp_degree"]) + if self.pp_degree > 1: + assert self.user_defined_strategy.pipeline == True + + self.dp_degree = int(self.user_defined_strategy.sharding_configs[ + 'dp_degree']) + assert self.role_maker._worker_num( + ) == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format( + self.role_maker._worker_num(), + self.mp_degree, + self.sharding_degree, + self.pp_degree, + self.dp_degree, ) + + self.hybrid_dp = self.user_defined_strategy.sharding_configs[ + "hybrid_dp"] + # NOTE (JZ-LIANG) + # there 2 kind of modes for gradient-merge and hybrid-dp in mixed parallism [sharding] and [pipeline]. + # we distinguish this two modes since the gm/hybrid-dp related allreduce should be insert in different place according different mode to have best performance: + # sharding: communication within node, and therefore should insert within backward segment to overlap with bw calc, conduct every micro step + # pipeline: communication accross nodes, and therefore should insert in update segemnt, conduct just once per global step + self.hybrid_dp_mode = None + # dp here is the pure dp as the outest parallelism + if self.hybrid_dp: + assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format( + self.dp_degree) + if self.pp_degree > 1: + self.hybrid_dp_mode = "pp_hybrid_dp" + else: + assert self.sharding_degree > 1, "by now we only support five kind of hybrid dp: sharding_hybrid_dp, mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp." 
+ self.hybrid_dp_mode = "sharding_hybrid_dp" + # gradient merge self._gradient_merge_acc_step = int( self.user_defined_strategy.sharding_configs[ "gradient_merge_acc_step"]) - self._grad2merged_grad = dict() + self.gradient_merge_mode = None + if self.pp_degree <= 1: + self.gradient_merge_mode = "sharding_gm" + self._grad2merged_grad = dict() + else: + self.gradient_merge_mode = "pp_gm" + self._gradient_merge_acc_step = self.user_defined_strategy.pipeline_configs[ + 'accumulate_steps'] + if self._gradient_merge_acc_step > 1: + logging.info("Gradient merge in [{}], acc step = [{}]".format( + self.gradient_merge_mode, self._gradient_merge_acc_step)) + + # optimize offload + self.optimize_offload = self.user_defined_strategy.sharding_configs[ + "optimize_offload"] + + # this feature is design for ascend, and should NOT be used in GPU training + self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[ + "pp_allreduce_in_optimize"] if self.inner_opt is None: raise ValueError( "self.inner_opt of ShardingOptimizer should not be None.") - optimize_ops, params_grads = self.inner_opt.minimize( - loss, startup_program, parameter_list, no_grad_set) + + if self.pp_degree > 1: + pp_optimizer = fluid.optimizer.PipelineOptimizer( + self.inner_opt, self._gradient_merge_acc_step) + main_program = loss.block.program + main_program._pipeline_opt = dict() + self.schedule_mode = self.user_defined_strategy.pipeline_configs[ + 'schedule_mode'] + main_program._pipeline_opt['schedule_mode'] = self.schedule_mode + main_program._pipeline_opt[ + 'micro_batch_size'] = self.user_defined_strategy.pipeline_configs[ + 'micro_batch_size'] + self.pp_rank_ = self.role_maker._worker_index() // ( + self.sharding_degree * self.mp_degree) % self.pp_degree + main_program._pipeline_opt['local_rank'] = self.pp_rank_ + main_program._pipeline_opt[ + 'global_rank'] = self.role_maker._worker_index() + main_program._pipeline_opt['use_sharding'] = True + # TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline + main_program._pipeline_opt['ring_id'] = 20 + main_program._pipeline_opt['global_ring_id'] = 3 + + optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize( + loss, startup_program, parameter_list, no_grad_set) + self.pp_degree = len(program_list) + else: + optimize_ops, params_grads = self.inner_opt.minimize( + loss, startup_program, parameter_list, no_grad_set) if startup_program is None: startup_program = default_startup_program() - main_block = loss.block + + if self.pp_degree > 1: + startup_program = startup_program._pipeline_opt['startup_program'] + #main_program = main_program._pipeline_opt['section_program']['program'] + print("pp_rank:", self.pp_rank_) + main_program = program_list[self.pp_rank_] + with open("main_%d" % self.role_maker._worker_index(), 'w') as f: + f.writelines(str(main_program)) + main_block = main_program.global_block() + new_params_grads = [] + for param, grad in params_grads: + if main_block.has_var(param.name): + new_params_grads.append((param, grad)) + params_grads = new_params_grads + + else: + main_block = loss.block + startup_block = startup_program.global_block() self._main_program = main_block.program self._startup_program = startup_program + if self.pp_degree > 1: + pp_optimizer._rename_gradient_var_name(main_block) + with open("main_%d" % self.role_maker._worker_index(), 'w') as f: + f.writelines(str(main_program)) + # step0: _init_comm self._init_comm() - # step1: _build_shard - 
self._build_shard(params_grads) - - # step2: split_program - self._split_program(main_block) + if self.sharding_degree > 1: - # step3: add broadcast and reduce ops - self._add_broadcast_allreduce(main_block) - main_block._sync_with_cpp() - startup_block._sync_with_cpp() + # step1: build shard + self._build_shard(params_grads) + + # step2: split_program + self._split_program(main_block) + + # step3: add broadcast and reduce ops + self._add_broadcast_allreduce(main_block) + main_block._sync_with_cpp() + startup_block._sync_with_cpp() + + main_block._sync_with_cpp() + + # step4: remove unneeded ops and vars from block + self._prune_main_program(main_block) + self._prune_startup_program(startup_block) + + if self.pp_degree > 1: + # sharding-pp related logic + # pp_optimizer._rename_gradient_var_name(main_block) + # crop ops + if self.sharding_degree > 1: + for idx, op in reversed(list(enumerate(main_block.ops))): + if is_update_op(op): + op_role_var = op.attr('op_role_var') + param_name = op_role_var[0] + if not self._shard.has_param(param_name): + main_block._remove_op(idx) + + for idx, op in reversed(list(enumerate(main_block.ops))): + if op.type != 'cast': continue + in_name = op.input_arg_names[0] + if in_name not in self._params: continue + #if self._shard.has_param(param_name): continue + if in_name not in main_block.vars: + main_block._remove_op(idx) + + accumulated_grad_names = pp_optimizer._accumulate_gradients( + main_block) + # accumulated_grad_names = sorted(accumulated_grad_names) + if self.pp_allreduce_in_optimize: + print("persistable FP32 grad: ") + print(accumulated_grad_names) + first_optimize_op_index = get_first_check_finite_and_unscale_op_idx( + main_block) + insert_reduce_ops( + main_block, + first_optimize_op_index, + self.sharding_ring_id, + accumulated_grad_names, + self._shard, + core.op_proto_and_checker_maker.OpRole.Optimize, + use_calc_stream=True) + if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp": + first_optimize_op_index = get_first_check_finite_and_unscale_op_idx( + main_block) + insert_allreduce_ops( + main_block, + first_optimize_op_index, + self.dp_ring_id, + accumulated_grad_names, + core.op_proto_and_checker_maker.OpRole.Optimize, + use_calc_stream=True) + + # if not use sharding, adapt amp/clip, for remain parallelism. + # cast --> amp --> clip --> opt + if self.sharding_degree <= 1: + # amp + FP16Utils.sync_amp_check_nan_inf(main_block, self.global_ring_id) + + # clip + gradientclip_helper = GradientClipHelper(self.global_ring_id) + gradientclip_helper.sync_global_norm( + main_block, self.global_ring_id, self.dp_degree) - # step4: scale the loss by the num of dp degree - # sharding is also a senario of dp - scale_ = self.dp_degree * self.sharding_degree - if scale_ > 1: - insert_scale_loss_grad_ops(main_block, scale=1.0 / scale_) + # step6: loss div dp_degree + global_dp_degree = self.sharding_degree * self.dp_degree + assert int(global_dp_degree) == global_dp_degree + if global_dp_degree > 1: + insert_scale_loss_grad_ops(main_block, scale=1.0 / global_dp_degree) main_block._sync_with_cpp() - # step5: remove unneeded ops and vars from block - self._prune_main_program(main_block) - self._prune_startup_program(startup_block) - if self.hybrid_dp: - self._initialization_broadcast(startup_program) - - # step6: optional gradient merge - if self._gradient_merge_acc_step > 1: + # TODO(wangxi): add optimize offload + # opt offload should be enable while gradient merge is enable && acc_step is quite large (e.g. 
>> 100) + # sync its memcpy could not be overlap with calc, otherwise it will slower down training severely. + if self.optimize_offload: + logging.info("Sharding with optimize offload !") + offload_helper = OffloadHelper() + offload_helper.offload(main_block, startup_block) + offload_helper.offload_fp32param(main_block, startup_block) + + # step6: (optional) sharding gradient merge + if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: self._sharding_gradient_merge(main_block) # # check op dependecy @@ -184,14 +343,29 @@ class ShardingOptimizer(MetaOptimizerBase): # check_broadcast(main_block) # check_allreduce_sum(main_block, self._shard, self.sharding_ring_id, # self.dp_ring_id) + + if self.hybrid_dp: + # NOTE(JZ-LIANG) ensure in both sharding_hybrid_dp & pp_hybrid_dp + # init param broadcast should be called after startup pruning + self._initialization_broadcast(startup_block) + + with open("start_sharding_%d" % self.role_maker._worker_index(), + 'w') as f: + f.writelines(str(startup_block.program)) + with open("main_sharding_%d" % self.role_maker._worker_index(), + 'w') as f: + f.writelines(str(main_block.program)) + self._wait() return optimize_ops, params_grads def _init_comm(self): + # config sharding & dp groups - self._build_group() + self._build_groups() + # sync var startup_block = self._startup_program.global_block() self.startup_prog_sync_var = startup_block.create_var( name="startup_prog_sync_var", @@ -199,7 +373,7 @@ class ShardingOptimizer(MetaOptimizerBase): dtype=core.VarDesc.VarType.INT32, persistable=False) - # global + # global ring self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, @@ -212,7 +386,7 @@ class ShardingOptimizer(MetaOptimizerBase): append_naive_sync(startup_block, self.startup_prog_sync_var, self.global_ring_id) - # mp + # mp ring if self.mp_degree > 1: self._collective_helper._init_communicator( self._startup_program, @@ -226,7 +400,7 @@ class ShardingOptimizer(MetaOptimizerBase): append_naive_sync(startup_block, self.startup_prog_sync_var, self.global_ring_id) - # sharding + # sharding ring if self.sharding_degree > 1: self._collective_helper._init_communicator( self._startup_program, @@ -240,7 +414,65 @@ class ShardingOptimizer(MetaOptimizerBase): append_naive_sync(startup_block, self.startup_prog_sync_var, self.global_ring_id) - # dp + # pp ring + if self.pp_degree > 1: + if self.schedule_mode == 'F-then-B': # GPipe + self._collective_helper._init_communicator( + self._startup_program, + self.current_endpoint, + self.pp_group_endpoints, + self.pp_rank, + self.pp_ring_id, + False, + global_ring_id=self.global_ring_id, + sync=False) + # append_naive_sync(startup_block, self.startup_prog_sync_var, + # self.global_ring_id) + self._collective_helper._init_communicator( + self._startup_program, + self.current_endpoint, + self.pp_group_endpoints, + self.pp_rank, + self.pp_ring_id + 2, + False, + global_ring_id=self.global_ring_id, + sync=False) + # append_naive_sync(startup_block, self.startup_prog_sync_var, + # self.global_ring_id) + else: + assert self.schedule_mode == '1F1B' + for pair in self.pipeline_pair: + pair_key = pair[0] * 1000 + pair[1] + ring_id = self.pp_ring_map[pair_key] + print("pp pair:{}, ring_id: {}".format(pair, ring_id)) + if self.pp_rank not in pair: continue + pp_group_endpoints = [ + self.pp_group_endpoints[pair[0]], + self.pp_group_endpoints[pair[1]], + ] + if pair[0] < pair[1]: + start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1 + else: + start_ring_id = 
self.pp_ring_id + 2 + pair[0] - pair[ + 1] - 1 + pp_rank = 0 if self.pp_rank == pair[0] else 1 + self._collective_helper._init_communicator( + self._startup_program, + self.current_endpoint, + pp_group_endpoints, + pp_rank, + ring_id, + False, + global_ring_id=self.global_ring_id, + sync=False) + # append_naive_sync(startup_block, self.startup_prog_sync_var, + # self.global_ring_id) + + # TODO (JZ-LIANG) to unify this shit + assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format( + self.pp_rank_, self.pp_rank) + + # pure dp ring if self.dp_degree > 1: self._collective_helper._init_communicator( self._startup_program, @@ -360,17 +592,22 @@ class ShardingOptimizer(MetaOptimizerBase): self._main_program.global_block().var(input_name)) # find reduce vars - if is_backward_op(op) and \ - OP_ROLE_VAR_KEY in op.attr_names: - op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] - if len(op_role_var) != 0: - assert len(op_role_var) % 2 == 0 - for i in range(0, len(op_role_var), 2): - param, reduced_grad = op_role_var[i], op_role_var[i + 1] - segment._allreduce_vars.append(reduced_grad) - assert ( - reduced_grad not in self._reduced_grads_to_param) - self._reduced_grads_to_param[reduced_grad] = param + if self.pp_degree > 1 and self.pp_allreduce_in_optimize: + # place pipeline gradient allreduce in optimize + pass + else: + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] + if len(op_role_var) != 0: + assert len(op_role_var) % 2 == 0 + for i in range(0, len(op_role_var), 2): + param, reduced_grad = op_role_var[i], op_role_var[ + i + 1] + segment._allreduce_vars.append(reduced_grad) + assert (reduced_grad not in + self._reduced_grads_to_param) + self._reduced_grads_to_param[reduced_grad] = param # find cast op if FP16Utils.is_fp16_cast_op(block, op, self._params): @@ -462,8 +699,13 @@ class ShardingOptimizer(MetaOptimizerBase): # Prune for idx, op in reversed(list(enumerate(block.ops))): if op.type in [ - "c_allreduce_sum", "c_sync_comm_stream", - "c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init" + "c_allreduce_sum", + "c_sync_comm_stream", + "c_calc_comm_stream", + "c_gen_nccl_id", + "c_comm_init", + 'send_v2', + 'recv_v2', ]: pass elif op.type == "conditional_block": @@ -500,6 +742,16 @@ class ShardingOptimizer(MetaOptimizerBase): if program_deps.should_remove_op(idx): program_deps.remove_op(idx) + # NOTE (JZ-LIANG) revise and unify logic here + # sharding support fp16_allreduce logic + block._sync_with_cpp() + for idx, op in reversed(list(enumerate(block.ops))): + if op.type == 'concat' and is_optimizer_op(op): + # remove inputs that not on this card + reserved_x = [] + for var_name in op.desc.input("X"): + if block.has_var(var_name): reserved_x.append(var_name) + op.desc.set_input('X', reserved_x) block._sync_with_cpp() return @@ -507,21 +759,41 @@ class ShardingOptimizer(MetaOptimizerBase): """ add broadcast allreduce op if enable gradient_merge, insert related ops + + if combined with pipeline(grad accumulate), + the grad allreduce should be done in optimize role """ if len(self._segments) < 1: return # sharding + if self.pp_degree > 1 and self.pp_allreduce_in_optimize: + for idx in range(len(self._segments)): + assert len(self._segments[idx]._allreduce_vars) == 0 + + # NOTE (JZ-LIANG) revise and unify logic here + # fix the _end_idx for segments[-1] if pp is used. 
+ new_end_idx = self._segments[-1]._end_idx + for idx in range(self._segments[-1]._end_idx - 1, + self._segments[-1]._start_idx - 1, -1): + op = block.ops[idx] + if op.type == "fill_constant" or op.type == "sum": + if "MERGED" in op.output_arg_names[0]: new_end_idx = idx + 1 + elif op.type == "cast": + if "@TMP" in op.output_arg_names[0]: new_end_idx = idx + 1 + self._segments[-1]._end_idx = new_end_idx + if self._segments[-1]._allreduce_vars: shard_allredue_vars = self._shard.filter_grads(self._segments[-1] ._allreduce_vars) - if self._gradient_merge_acc_step <= 1: - if self.hybrid_dp and len(shard_allredue_vars) >= 1: + if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: + if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( + shard_allredue_vars) >= 1: insert_sync_comm_ops(block, self._segments[-1]._end_idx, self.dp_ring_id, shard_allredue_vars) insert_allreduce_ops(block, self._segments[-1]._end_idx, self.dp_ring_id, shard_allredue_vars) # gradient merge - else: + elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: self.create_persistable_gradients_and_insert_merge_ops( block, self._startup_program.global_block(), @@ -532,9 +804,14 @@ class ShardingOptimizer(MetaOptimizerBase): self.sharding_ring_id, self._segments[-1]._allreduce_vars) # allreduce --> reduce - insert_reduce_ops(block, self._segments[-1]._end_idx, - self.sharding_ring_id, - self._segments[-1]._allreduce_vars, self._shard) + insert_reduce_ops( + block, + self._segments[-1]._end_idx, + self.sharding_ring_id, + self._segments[-1]._allreduce_vars, + self._shard, + op_role=OpRole.Backward, + use_calc_stream=False) for idx, segment in reversed(list(enumerate(self._segments))): allreduce_vars = self._segments[ @@ -574,8 +851,9 @@ class ShardingOptimizer(MetaOptimizerBase): # step2: add Sync ops shard_allredue_vars = self._shard.filter_grads(allreduce_vars) - if self._gradient_merge_acc_step <= 1: - if self.hybrid_dp and len(shard_allredue_vars) >= 1: + if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: + if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( + shard_allredue_vars) >= 1: insert_sync_comm_ops(block, segment._end_idx, self.dp_ring_id, shard_allredue_vars) @@ -593,7 +871,7 @@ class ShardingOptimizer(MetaOptimizerBase): self.sharding_ring_id, comm_dep_vars) # gradient merge - else: + elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: broad_cast_vars = [x[0] for x in broadcast_vars] if len(broad_cast_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, @@ -616,7 +894,7 @@ class ShardingOptimizer(MetaOptimizerBase): # step5: add broadcast ops # gradient merge - if self._gradient_merge_acc_step > 1: + if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: self.create_persistable_gradients_and_insert_merge_ops( block, self._startup_program.global_block(), segment._start_idx, @@ -627,20 +905,29 @@ class ShardingOptimizer(MetaOptimizerBase): # step6: add all_reduce ops # dp - if self._gradient_merge_acc_step <= 1: - if self.hybrid_dp and len(shard_allredue_vars) >= 1: + if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: + if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( + shard_allredue_vars) >= 1: insert_allreduce_ops(block, segment._start_idx, self.dp_ring_id, shard_allredue_vars) insert_sync_comm_ops(block, segment._start_idx, self.sharding_ring_id, 
allreduce_vars) # gradient merge - else: + elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: insert_sync_comm_ops(block, segment._start_idx, self.sharding_ring_id, allreduce_vars) # sharding # allreduce --> reduce - insert_reduce_ops(block, segment._start_idx, self.sharding_ring_id, - allreduce_vars, self._shard) + # TODO temp change + if len(allreduce_vars) > 0: + insert_reduce_ops( + block, + segment._start_idx, + self.sharding_ring_id, + allreduce_vars, + self._shard, + op_role=OpRole.Backward, + use_calc_stream=False) block._sync_with_cpp() @@ -691,14 +978,14 @@ class ShardingOptimizer(MetaOptimizerBase): block._remove_var(var_name, sync=False) block._sync_with_cpp() - def _build_group(self): + def _build_groups(self): """ pre-assign ring ids - mp: 0 - sharding: 1 - pure-dp: 2 - global: 3 - pp: >= 20 + mp: 0 + sharding: 1 + pure-dp: 2 + global: 3 + pp: >= 20 if one parallelism is not enable: -1 and only support parallelism hierarchy: mp --> sharding --> pp --> dp """ @@ -768,6 +1055,30 @@ class ShardingOptimizer(MetaOptimizerBase): self.sharding_group_id = -1 self.sharding_group_endpoints = [] + # pp + if self.pp_degree > 1: + self.pp_ring_id = 20 + self.pp_rank = self.global_rank // (self.sharding_degree * + self.mp_degree) % self.pp_degree + # (NOTE): Already adjust for (outter-pure) dp + self.pp_group_id = self.global_rank // ( + self.mp_degree * self.sharding_degree * self.pp_degree) + pp_first_stage_idx = self.global_rank % ( + self.sharding_degree * self.mp_degree) + self.pp_group_id * ( + self.mp_degree * self.sharding_degree * self.pp_degree) + pp_stage_offset = self.sharding_degree * self.mp_degree + self.pp_group_endpoints = [] + for i in range(self.pp_degree): + self.pp_group_endpoints.append(self.global_endpoints[ + pp_first_stage_idx + pp_stage_offset * i]) + assert self.current_endpoint in self.pp_group_endpoints + else: + self.pp_degree = 1 + self.pp_ring_id = -1 + self.pp_rank = -1 + self.pp_group_id = -1 + self.pp_group_endpoints = [] + # outter-pure-dp group # NOTE (JZ-LIANG) support outter-pure-dp to scale the throughput in 3D parallelism # e.g. 
mp-sharding-pp-dp @@ -775,6 +1086,7 @@ class ShardingOptimizer(MetaOptimizerBase): assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format( self.mp_degree, self.sharding_degree, self.pp_degree, self.dp_degree, self.global_word_size) + if self.dp_degree > 1: self.dp_ring_id = 2 self.dp_rank = self.global_rank // (self.sharding_degree * @@ -794,6 +1106,8 @@ class ShardingOptimizer(MetaOptimizerBase): self.dp_group_endpoints = [] # global group + # use for gen_nccl_comm_sync, amp check nan inf, clip by global norm + # NOTE (JZ-LIANG) when use global ring for calc global norm and dp_degree > 1, the allreduce result should be devided by dp_degree self.global_ring_id = 3 logging.info("global word size: {}".format(self.global_word_size)) @@ -817,25 +1131,31 @@ class ShardingOptimizer(MetaOptimizerBase): logging.info("sharding ring id: {}".format(self.sharding_ring_id)) logging.info("#####" * 6) - logging.info("outter pure dp group size: {}".format(self.dp_degree)) - logging.info("outter pure dp rank: {}".format(self.dp_rank)) - logging.info("outter pure dp group endpoints: {}".format( + logging.info("pp group size: {}".format(self.pp_degree)) + logging.info("pp rank: {}".format(self.pp_rank)) + logging.info("pp group id: {}".format(self.pp_group_id)) + logging.info("pp group endpoints: {}".format(self.pp_group_endpoints)) + logging.info("pp ring id: {}".format(self.pp_ring_id)) + logging.info("#####" * 6) + + logging.info("pure dp group size: {}".format(self.dp_degree)) + logging.info("pure dp rank: {}".format(self.dp_rank)) + logging.info("pure dp group endpoints: {}".format( self.dp_group_endpoints)) - logging.info("outter pure dp ring id: {}".format(self.dp_ring_id)) + logging.info("pure dp ring id: {}".format(self.dp_ring_id)) logging.info("#####" * 6) return - def _initialization_broadcast(self, startup_prog): + def _initialization_broadcast(self, startup_block): """ this funtion is to ensure the initialization between dp group to be identical when hybrid-dp is used. 
""" - block = startup_prog.global_block() params = [] - for param in block.iter_parameters(): + for param in startup_block.iter_parameters(): params.append(param) - block.append_op( + startup_block.append_op( type='c_broadcast', inputs={'X': param}, outputs={'Out': param}, @@ -844,15 +1164,14 @@ class ShardingOptimizer(MetaOptimizerBase): 'root': 0, OP_ROLE_KEY: OpRole.Forward }) - block.append_op( + startup_block.append_op( type='c_sync_comm_stream', inputs={'X': params}, outputs={'Out': params}, attrs={'ring_id': self.dp_ring_id, OP_ROLE_KEY: OpRole.Forward}) - # sync within global group - append_naive_sync(block, self.startup_prog_sync_var, + append_naive_sync(startup_block, self.startup_prog_sync_var, self.global_ring_id) # sharding gradient merge diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index b3a1834d49d..572ebb26d73 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars): new_op_desc = block.desc.append_op() new_op_desc.copy_from(desc) new_op_desc._set_attr(op_role_attr_name, backward) + if desc.has_attr('op_device'): + new_op_desc._set_attr('op_device', desc.attr('op_device')) result_descs.append(new_op_desc) return result_descs @@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block): new_op_desc = block.desc.append_op() new_op_desc.copy_from(desc) new_op_desc._set_attr(op_role_attr_name, backward) + if desc.has_attr('op_device'): + new_op_desc._set_attr('op_device', desc.attr('op_device')) result_descs.append(new_op_desc) return result_descs @@ -843,6 +847,7 @@ def _append_backward_ops_with_checkpoints_( vars_in_memory = vars_should_be_hold + checkpoints_name max_calculated_op_position = len(ops) + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() if recompute_segments == []: gap_ops = ops[0:max_calculated_op_position] for op in reversed(gap_ops): @@ -852,6 +857,11 @@ def _append_backward_ops_with_checkpoints_( _pretty_op_desc_(op.desc, "with_sub_block")) grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, cpt.to_text(no_grad_dict[block.idx]), []) + # Set device for grad_op according to forward Op + if op.desc.has_attr(device_attr_name): + op_device = op.desc.attr(device_attr_name) + for op_desc in grad_op_desc: + op_desc._set_attr(device_attr_name, op_device) added_descs = _add_descs_to_block(grad_op_desc, local_block) grad_op_descs.extend(added_descs) grad_to_var.update(op_grad_to_var) @@ -866,6 +876,11 @@ def _append_backward_ops_with_checkpoints_( _pretty_op_desc_(op.desc, "with_sub_block")) grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, cpt.to_text(no_grad_dict[block.idx]), []) + # Set device for grad_op according to forward Op + if op.desc.has_attr(device_attr_name): + op_device = op.desc.attr(device_attr_name) + for op_desc in grad_op_desc: + op_desc._set_attr(device_attr_name, op_device) added_descs = _add_descs_to_block(grad_op_desc, local_block) grad_op_descs.extend(added_descs) grad_to_var.update(op_grad_to_var) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 76c5a309103..27ce44a257e 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4033,6 +4033,12 @@ class PipelineOptimizer(object): """ Find the post op that has variable named var_name as input. 
""" + # bugfix for uniform hybrid parallelism + if '.cast_fp32' in var_name: + var_name = var_name.replace('.cast_fp32', '') + if '.cast_fp16' in var_name: + var_name = var_name.replace('.cast_fp16', '') + post_ops = self.input_var_to_op[var_name] if post_ops == None: return None result_op = None @@ -4114,7 +4120,23 @@ class PipelineOptimizer(object): # For LRSched ops, we should put them on all sub-programs to # make sure each sub-program update the lr correctly op._set_attr(self._op_device_key, "gpu:all") - elif op.type == "scale" and self._is_backward_op(op): + # bugfix in hybrid parallelism + elif op.type == "sum" and self._is_backward_op(op): + # For sum ops that compute the sum of @RENAMED@ vars + for name in op.desc.input_arg_names(): + assert '@RENAME@' in name, \ + "The op must be sum used to accumulate renamed vars." + assert len(op.desc.output_arg_names()) == 1 + out_name = op.desc.output_arg_names()[0] + post_op = self._find_post_op(idx, out_name) + assert post_op.has_attr( + 'op_device'), "{} has no op_device attr for var {}".format( + post_op.type, out_name) + device = post_op.attr(self._op_device_key) + assert device, "The post op must have op_device set." + op._set_attr(self._op_device_key, device) + elif (op.type == "cast" or + op.type == "scale") and self._is_backward_op(op): prev_op = self._find_prev_op(idx, op.desc.input("X")[0]) op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key)) elif op.type == "memcpy" and not self._is_optimize_op(op): @@ -4249,11 +4271,19 @@ class PipelineOptimizer(object): Insert a pair of send and recv ops for every two consecutive ops on different devices. """ - extra_index_info = {'index': 0} - # A map from var to device where op takes it as input, # avoiding multiple send and recv ops. input_var_to_device = dict() + # bugfix hybrid parallelism + first_optimize_index = None + for index, op in enumerate(list(block.ops)): + if self._is_optimize_op(op): + first_optimize_index = index + break + extra_index_info = { + 'index': 0, + 'first_optimize_index': first_optimize_index + } for index, op in enumerate(list(block.ops)): cur_device = op.attr(self._op_device_key) @@ -4371,17 +4401,26 @@ class PipelineOptimizer(object): 'peer': 1, }) extra_index_info['index'] += 1 + insert_index = None + if int(op_role) == int(self._op_role.Backward): + insert_index = extra_index_info[ + 'first_optimize_index'] + new_op_role = self._op_role.Optimize + else: + insert_index = index + new_op_role = self._op_role.Backward block._insert_op( - index=index + extra_index_info['index'], + index=insert_index + extra_index_info['index'], type='c_sync_comm_stream', inputs={'X': [var]}, outputs={'Out': [var]}, attrs={ self._op_device_key: prev_dev, - self._op_role_key: self._op_role.Backward, + self._op_role_key: new_op_role, 'ring_id': ring_id, }) - extra_index_info['index'] += 1 + if int(op_role) == int(self._op_role.Forward): + extra_index_info['index'] += 1 var_shape = list(var.shape) var_shape[0] = self.micro_batch_size if var_shape[ 0] < 0 else var_shape[0] @@ -4768,8 +4807,9 @@ class PipelineOptimizer(object): # Step4: Special Case: process persistable vars that exist in # multiple sections - self._process_persistable_vars_in_multi_sections( - main_program, startup_program, program_list) + # FIXME + # self._process_persistable_vars_in_multi_sections( + # main_program, startup_program, program_list) # Step5: Add sub blocks for section programs self._add_sub_blocks(main_block, program_list) diff --git 
a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index 4d6744f2b6f..f28bf89ff5c 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -354,6 +354,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer): "segment_broadcast_MB": 0.2, "segment_anchors": None, "sharding_degree": 2, + "dp_degree": 2, "hybrid_dp": True, "gradient_merge_acc_step": 1, "mp_degree": 1 @@ -422,6 +423,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer): "segment_broadcast_MB": 0.2, "segment_anchors": None, "sharding_degree": 2, + "dp_degree": 2, "hybrid_dp": True, "gradient_merge_acc_step": 4, "mp_degree": 1 @@ -458,20 +460,56 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer): fw_bw_ops = [op.type for op in train_prog.blocks[0].ops] opt_ops = [op.type for op in train_prog.blocks[2].ops] self.assertEqual(fw_bw_ops, [ - 'fill_constant', 'fill_constant', 'fill_constant', - 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', - 'c_sync_comm_stream', 'mul', 'elementwise_add', 'tanh', 'mul', - 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax', - 'cross_entropy2', 'mean', 'fill_constant', 'scale', 'mean_grad', - 'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad', - 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', - 'tanh_grad', 'elementwise_add_grad', 'mul_grad', - 'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum', - 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', - 'c_sync_comm_stream', 'elementwise_add', 'elementwise_add', - 'elementwise_add', 'increment', 'elementwise_mod', 'equal', - 'conditional_block' + 'fill_constant', + 'fill_constant', + 'fill_constant', + 'c_sync_calc_stream', + 'c_broadcast', + 'c_broadcast', + 'c_broadcast', + 'c_broadcast', + 'c_broadcast', + 'c_broadcast', + 'c_sync_comm_stream', + 'mul', + 'elementwise_add', + 'tanh', + 'mul', + 'elementwise_add', + 'tanh', + 'mul', + 'elementwise_add', + 'softmax', + 'cross_entropy2', + 'mean', + 'fill_constant', + 'scale', + 'mean_grad', + 'cross_entropy_grad2', + 'softmax_grad', + 'elementwise_add_grad', + 'mul_grad', + 'tanh_grad', + 'elementwise_add_grad', + 'mul_grad', + 'tanh_grad', + 'elementwise_add_grad', + 'mul_grad', + 'c_sync_calc_stream', + 'c_reduce_sum', + 'c_reduce_sum', + 'c_reduce_sum', + 'c_reduce_sum', + 'c_reduce_sum', + 'c_reduce_sum', + 'c_sync_comm_stream', + 'elementwise_add', + 'elementwise_add', + 'elementwise_add', + 'increment', + 'elementwise_mod', + 'equal', + 'conditional_block', ]) self.assertEqual(opt_ops, [ 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale', -- GitLab
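Editor's usage sketch (not part of the patch): the snippet below is a minimal, hypothetical example of how the new ShardingConfig fields added in distributed_strategy.proto might be driven from user code. The key names mirror the proto fields and the updated unit test in this patch; the concrete degrees, schedule mode, and the commented-out optimizer/model lines are illustrative assumptions only.

# Hypothetical configuration sketch for the hybrid parallelism knobs added in
# this patch (mp -> sharding -> pp -> outer dp). Assumes the job is launched
# with fleet launch on mp * sharding * pp * dp = 1 * 2 * 2 * 2 = 8 ranks.
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_segment_strategy": "segment_broadcast_MB",
    "segment_broadcast_MB": 32.0,
    "mp_degree": 1,
    "sharding_degree": 2,
    "pp_degree": 2,                  # > 1 additionally requires strategy.pipeline
    "dp_degree": 2,                  # outer pure data parallelism
    "hybrid_dp": True,
    "gradient_merge_acc_step": 1,
    "optimize_offload": False,
    "pp_allreduce_in_optimize": False,
}
# With pp_degree > 1 the sharding optimizer reads schedule_mode,
# micro_batch_size and accumulate_steps from pipeline_configs.
strategy.pipeline = True
strategy.pipeline_configs = {
    "schedule_mode": "1F1B",
    "micro_batch_size": 2,
    "accumulate_steps": 4,
}

# optimizer = fleet.distributed_optimizer(inner_optimizer, strategy=strategy)
# optimizer.minimize(loss)   # inner_optimizer/loss come from the user's model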
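Editor's note on gradient naming (not part of the patch): the get_grad_device helper added in sharding/utils.py maps a gradient name, possibly carrying ".cast_fp16" and "@MERGED" decorations, back to the rank that owns the parameter, and its suffix list must be traversed longest-first. The standalone illustration below uses plain suffix stripping instead of the patch's re.sub call; the parameter names and rank mapping are made up for the example.

# Illustration of the gradient-name suffix handling behind get_grad_device():
# '.cast_fp16@GRAD@MERGED' must be tried before '@GRAD@MERGED', otherwise the
# '.cast_fp16' decoration would be left on the recovered parameter name.
POSSIBLE_SUFFIXES = [
    '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
]

def base_param_name(grad_name):
    for suffix in POSSIBLE_SUFFIXES:
        if grad_name.endswith(suffix):
            return grad_name[:-len(suffix)]
    raise ValueError("[{}] is not a gradient name".format(grad_name))

# hypothetical stand-in for shard.global_param2device
param2device = {"fc_0.w_0": 0, "fc_1.b_0": 1}

assert base_param_name("fc_0.w_0.cast_fp16@GRAD@MERGED") == "fc_0.w_0"
assert base_param_name("fc_1.b_0@GRAD") == "fc_1.b_0"
assert param2device[base_param_name("fc_0.w_0.cast_fp16@GRAD")] == 0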