diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
index ff05e039f5c2b142ae0793f2f8ec3a4f66d6eed4..788d4e526352a9197bd47f791a2df4e90344739e 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
@@ -81,7 +81,9 @@ class FP16Utils(object):
             if not FP16Utils.is_fp32_cast_op(block, op):
                 continue
             output_name = op.desc.output_arg_names()[0]
-            param_name = output_name.strip("@GRAD")
+            param_name = output_name.strip(
+                "@GRAD@MERGED"
+            ) if "@MERGED" in output_name else output_name.strip("@GRAD")
             if param_name not in shard.global_params:
                 raise ValueError("Output 'X' of cast_op must be a grad of"
                                  "model param, but {} is not a grad".format(
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
index 3b0cfe21a79492737926638fafc66b3cb8ba320c..340eff46f7341697a83b1d1f1ead5576cb3dc15b 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -41,7 +41,7 @@ class GradientClipHelper(object):
             for input_name in op.desc.input_arg_names():
                 if input_name in deperated_vars:
                     deperate_op = True
-                param_name = input_name.strip("@GRAD")
+                param_name = input_name.strip("@GRAD@MERGED")
                 if shard.is_param(param_name) and \
                         not shard.has_param(param_name):
                     deperate_op = True
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index ce57d812efef237e9726ef3073b9c45e5462394b..12d0bc7394082cbe29b18b5fac871fbe402e5504 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -188,13 +188,6 @@ class ShardingOptimizer(MetaOptimizerBase):
             # pp_optimizer._rename_gradient_var_name(main_block)
             # crop ops
             for idx, op in reversed(list(enumerate(main_block.ops))):
-                # if op.type == 'fill_constant' and int(op.attr('op_role')) == 16:
-                #     out_name = op.output_arg_names[0]
-                #     if not 'GRAD' in out_name: continue
-                #     param_name = out_name.strip("@GRAD")
-                #     #if main_block.has_var(out_name): continue
-                #     if self._shard.has_param(param_name): continue
-                #     main_block._remove_op(idx)
                 if is_update_op(op):
                     op_role_var = op.attr('op_role_var')
                     param_name = op_role_var[0]
@@ -208,13 +201,6 @@ class ShardingOptimizer(MetaOptimizerBase):
                         #if self._shard.has_param(param_name): continue
                         if in_name not in main_block.vars:
                             main_block._remove_op(idx)
-            #param_list = []
-            #for param_name, grad_name in params_grads:
-            #    if self._shard.has_param(param_name):
-            #        param_list.append(param_name)
-            #pp_optimizer._clear_gradients(main_block, param_list)
-            #accumulated_grad_names = pp_optimizer._accumulate_gradients(
-            #    main_block)
             # accumulated_grad_names = sorted(accumulated_grad_names)
             if self.pp_allreduce_in_optimize:
                 print("persistable FP32 grad: ")
@@ -229,149 +215,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                     self._shard,
                     core.op_proto_and_checker_maker.OpRole.Optimize,
                     use_calc_stream=True)
-            #if not self._shard.has_param(param_name): continue
-            ##if not main_block.has_var(grad_name): continue
-            #assert main_block.has_var(grad_name)
-            #grad_var = main_block.vars[grad_name]
-            #grad_var.persistable = True
-            #main_block._insert_op(
-            #    index=0,
-            #    type='fill_constant',
-            #    inputs={},
-            #    outputs={'Out': [grad_var]},
-            #    attrs={
-            #        'shape': grad_var.shape,
-            #        'dtype': grad_var.dtype,
-            #        'value': float(0),
-            #        #self._op_device_key: device,
-            #        # a trick to run this op once per mini-batch
-            #        'op_role': core.op_proto_and_checker_maker.OpRole.LRSched,
-            #    })
-
-            #def _create_var(block, ref_var, name):
-            #    """
-            #    Create a new var for block, which has the same type,
-            #    shape and dtype as ref_var, then rename it with the
-            #    name `name`.
-            #    """
-            #    new_var = block.create_var(
-            #        name=name,
-            #        shape=ref_var.shape,
-            #        dtype=ref_var.dtype,
-            #        type=ref_var.type,
-            #        lod_level=ref_var.lod_level,
-            #        persistable=ref_var.persistable,
-            #        is_data=ref_var.is_data,
-            #        need_check_feed=ref_var.desc.need_check_feed())
-            #    new_var.stop_gradient = ref_var.stop_gradient
-            #    return new_var
-
-            #def _rename_arg(op, old_name, new_name):
-            #    op_desc = op.desc
-            #    if isinstance(op_desc, tuple):
-            #        op_desc = op_desc[0]
-            #    op_desc._rename_input(old_name, new_name)
-            #    op_desc._rename_output(old_name, new_name)
-
-            #print("params_grads:", params_grads)
-            #for param_name, grad_name in params_grads:
-            #    if not self._shard.has_param(param_name): continue
-            #    #if not main_block.has_var(grad_name): continue
-            #    assert main_block.has_var(grad_name)
-            #    use_fp16 = False
-            #    fp16_grad_name = param_name + '.cast_fp16@GRAD'
-            #    if main_block.has_var(grad_name):
-            #        fp16_grad_var = main_block.vars[fp16_grad_name]
-            #        use_fp16 = True
-            #    grad_var = main_block.vars[grad_name]
-            #    if use_fp16:
-            #        cast_grad_var_name = paddle.fluid.unique_name.generate(
-            #            grad_name)
-            #        cast_var = _create_var(main_block, fp16_grad_var,
-            #                               cast_grad_var_name)
-            #        cast_var.persistable = False
-            #        main_block.append_op(
-            #            #index=offset + 1,
-            #            type='cast',
-            #            inputs={'X': grad_var},
-            #            outputs={'Out': cast_var},
-            #            attrs={
-            #                'in_dtype': grad_var.dtype,
-            #                'out_dtype': cast_var.dtype,
-            #                'op_role':
-            #                core.op_proto_and_checker_maker.OpRole.Backward,
-            #            })
-            #    #offset += 1
-            #    main_block.append_op(
-            #        #index=offset + 1,
-            #        type='sum',
-            #        inputs={'X': [fp16_grad_var, cast_var]},
-            #        outputs={'Out': fp16_grad_var},
-            #        attrs={
-            #            'op_role':
-            #            core.op_proto_and_checker_maker.OpRole.Backward,
-            #            'op_role_var': op_role_var
-            #        })
-
-            # for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
-            #     offset = index
-            #     if is_backward_op(op) and (
-            #             'op_role_var' in op.attr_names):
-            #         op_role_var = op.all_attrs()['op_role_var']
-
-            #         if len(op_role_var) == 0:
-            #             continue
-            #         assert len(op_role_var) % 2 == 0
-            #         offset = index
-            #         for i in range(0, len(op_role_var), 2):
-            #             grad_name = op_role_var[i + 1]
-            #             if not main_block.has_var(grad_name): continue
-            #             grad_var = main_block.vars[grad_name]
-            #             if not 'cast_fp16' in grad_name:
-            #                 new_grad_var_name = paddle.fluid.unique_name.generate(grad_name)
-            #                 new_var = _create_var(main_block, grad_var,
-            #                                       new_grad_var_name)
-            #                 new_var.persistable = False
-            #                 _rename_arg(op, grad_name, new_grad_var_name)
-            #                 main_block._insert_op(
-            #                     index=offset + 1,
-            #                     type='sum',
-            #                     inputs={'X': [grad_var, new_var]},
-            #                     outputs={'Out': grad_var},
-            #                     attrs={
-            #                         'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
-            #                         'op_role_var': op_role_var
-            #                     })
-            #                 offset += 1
-            #             if 'cast_fp16' in grad_name:
-            #                 param_name = op_role_var[i]
-            #                 fp32_grad_var_name = param_name + "@GRAD"
-            #                 fp32_grad_var = main_block.vars[grad_name]
-            #                 cast_grad_var_name = paddle.fluid.unique_name.generate(
-            #                     fp32_grad_var_name)
-            #                 cast_var = _create_var(main_block, grad_var,
-            #                                        cast_grad_var_name)
-            #                 cast_var.persistable = False
-            #                 main_block._insert_op(
-            #                     index=offset + 1,
-            #                     type='cast',
-            #                     inputs={'X': fp32_grad_var},
-            #                     outputs={'Out': cast_var},
-            #                     attrs={
-            #                         'in_dtype': fp32_grad_var.dtype,
-            #                         'out_dtype': cast_var.dtype,
-            #                         'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
-            #                         # self._op_role_var_key: op_role_var
-            #                     })
-            #                 offset += 1
-            #                 main_block._insert_op(
-            #                     index=offset + 1,
-            #                     type='sum',
-            #                     inputs={'X': [grad_var, cast_var]},
-            #                     outputs={'Out': grad_var},
-            #                     attrs={
-            #                         'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
-            #                         'op_role_var': op_role_var})
+            main_block._sync_with_cpp()
 
         # TODO(wangxi): add optimize offload
 
@@ -699,6 +543,17 @@ class ShardingOptimizer(MetaOptimizerBase):
             for idx in range(len(self._segments)):
                 assert len(self._segments[idx]._allreduce_vars) == 0
 
+        # fix the _end_idx for segments[-1] if pp is used.
+        new_end_idx = self._segments[-1]._end_idx
+        for idx in range(self._segments[-1]._end_idx - 1,
+                         self._segments[-1]._start_idx - 1, -1):
+            op = block.ops[idx]
+            if op.type == "fill_constant" or op.type == "sum":
+                if "MERGED" in op.output_arg_names[0]: new_end_idx = idx + 1
+            elif op.type == "cast":
+                if "@TMP" in op.output_arg_names[0]: new_end_idx = idx + 1
+        self._segments[-1]._end_idx = new_end_idx
+
         if self._segments[-1]._allreduce_vars:
             shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
                                                            ._allreduce_vars)
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 3630cfd0eafeb916df6ed0f722755fd925085a0e..bbea0ed0662f52c8baf12db46504fad47bebb063 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4293,7 +4293,7 @@ class PipelineOptimizer(object):
             input_name = op.input_arg_names[0]
             output_name = op.output_arg_names[0]
             if '@Fetch' in output_name:
-                post_op = self._find_real_post_op(block.ops, op, output_name)
+                post_op = self._find_post_op(block.ops, op, output_name)
                 op._set_attr('op_device', post_op.attr('op_device'))
             else:
                 prev_op = self._find_real_prev_op(block.ops, op,
@@ -4849,6 +4849,9 @@ class PipelineOptimizer(object):
         Create a new merged gradient for each parameter and accumulate the
         corresponding gradient to it.
         """
+        merged_gradient_names = []
+        first_opt_op_idx = None
+
         for index, op in reversed(tuple(enumerate(list(block.ops)))):
             # remove the cast op of fp16 grad to fp32 grad
             if self._is_optimize_op(op) and op.type == 'cast':
@@ -4859,6 +4862,9 @@ class PipelineOptimizer(object):
                     block._remove_op(index)
                 continue
 
+            if self._is_backward_op(op) and not first_opt_op_idx:
+                first_opt_op_idx = index + 1
+
             if self._is_backward_op(op) and (
                     self._op_role_var_key in op.attr_names):
                 op_role_var = op.attr(self._op_role_var_key)
@@ -4868,7 +4874,7 @@ class PipelineOptimizer(object):
                 assert len(op_role_var) % 2 == 0
                 op._remove_attr(self._op_role_var_key)
                 for i in range(0, len(op_role_var), 2):
-                    offset = 1
+                    offset = 0
                     param_name = op_role_var[i]
                     assert block.has_var(param_name), (
                         "parameter {} not in "
@@ -4886,7 +4892,7 @@ class PipelineOptimizer(object):
                     merged_param_grad_var = block.var(merged_param_grad_name)
                     merged_param_grad_var.persistable = True
                     block._insert_op(
-                        index=index + offset,
+                        index=first_opt_op_idx + offset,
                         type='fill_constant',
                         inputs={},
                         outputs={'Out': [merged_param_grad_var]},
@@ -4902,7 +4908,7 @@ class PipelineOptimizer(object):
                     grad_var = block.vars[grad_name]
                     if not 'cast_fp16' in grad_name:
                         block._insert_op(
-                            index=index + offset,
+                            index=first_opt_op_idx + offset,
                             type='sum',
                             inputs={'X': [grad_var, merged_param_grad_var]},
                             outputs={'Out': merged_param_grad_var},
@@ -4918,7 +4924,7 @@ class PipelineOptimizer(object):
                                                          cast_grad_var_name)
                         cast_grad_var.persistable = False
                         block._insert_op(
-                            index=index + offset,
+                            index=first_opt_op_idx + offset,
                             type='cast',
                             inputs={'X': grad_var},
                             outputs={'Out': cast_grad_var},
@@ -4929,7 +4935,7 @@ class PipelineOptimizer(object):
                             })
                         offset += 1
                         block._insert_op(
-                            index=index + offset,
+                            index=first_opt_op_idx + offset,
                             type='sum',
                             inputs={
                                 'X': [merged_param_grad_var, cast_grad_var]
@@ -5705,10 +5711,10 @@ class RecomputeOptimizer(Optimizer):
 
             for output_var in output_vars:
                 if output_var in need_offload_checkpoint_names:
-                    assert len(
-                        output_vars
-                    ) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
-                        output_var, op)
+                    #assert len(
+                    #    output_vars
+                    #) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
+                    #    output_var, op)
 
                     if output_var in self.un_offload_checkpoint_names:
                         # insert sync op if last checkpoint has not been sync
@@ -5733,14 +5739,14 @@ class RecomputeOptimizer(Optimizer):
                             format(output_var))
                 # need to sync the last need to offload checkpoint before the last checkpoint as output op
                 if output_var == last_checkpoint:
-                    assert len(
-                        output_vars
-                    ) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
-                        output_var, op)
-                    assert last_offload_checkpoint == self.sorted_checkpoint_names[
-                        -2], "the last offload chekpoint before [{}] is suppose to be [{}], but got [{}]".format(
-                        last_checkpoint, self.sorted_checkpoint_names[-2],
-                        last_offload_checkpoint)
+                    #assert len(
+                    #    output_vars
+                    #) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
+                    #    output_var, op)
+                    #assert last_offload_checkpoint == self.sorted_checkpoint_names[
+                    #    -2], "the last offload chekpoint before [{}] is suppose to be [{}], but got [{}]".format(
+                    #    last_checkpoint, self.sorted_checkpoint_names[-2],
+                    #    last_offload_checkpoint)
                     # sync if last checkpoint has not been sync
                     if self.checkpoint_usage_count_and_idx[
                             last_offload_checkpoint]['idx'] == 0: