diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index b03858119296e419d5eeffe436016fb5b1c050c1..34c46e446af2731422d5d9b0fcc5baab86d2f377 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -822,6 +822,28 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context): # TODO to add attribute for moment var op = ops[idx] if int(op.attr('op_role')) == int(OpRole.Optimize): + if op.type == "clip_by_norm": + + param_grad = vars[op.input("X")[0]] + param_grad_dist_attr = dist_context.get_tensor_dist_attr_for_program( + param_grad) + assert param_grad_dist_attr is not None + ref_process_mesh = param_grad_dist_attr.process_mesh + ref_dims_mapping = param_grad_dist_attr.dims_mapping + + out = vars[op.output("Out")[0]] + out_dist_attr = TensorDistributedAttribute() + out_dist_attr.process_mesh = ref_process_mesh + out_dist_attr.dims_mapping = ref_dims_mapping + dist_context.set_tensor_dist_attr_for_program(out, + out_dist_attr) + + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.process_mesh = ref_process_mesh + op_dist_attr.set_input_dist_attr(param_grad.name, + param_grad_dist_attr) + op_dist_attr.set_output_dist_attr(out.name, out_dist_attr) + dist_context.set_op_dist_attr_for_program(op, op_dist_attr) if "Grad" in op.input_names and "Param" in ops[idx].input_names: assert len(op.input( diff --git a/python/paddle/distributed/auto_parallel/dist_attribute.py b/python/paddle/distributed/auto_parallel/dist_attribute.py index 4415448769d01ce52b73c102d87edeb6b8517d80..b27cd7a37c95626584194ae7bd619ab16a0e5ea7 100644 --- a/python/paddle/distributed/auto_parallel/dist_attribute.py +++ b/python/paddle/distributed/auto_parallel/dist_attribute.py @@ -21,7 +21,9 @@ _g_tensor_dist_attr_field_keys = [ "process_mesh", "dims_mapping", "shard_sizes", "device_placement" ] -_g_op_dist_attr_field_keys = ["process_mesh", "impl_type", "impl_idx"] +_g_op_dist_attr_field_keys = [ + "process_mesh", "impl_type", "impl_idx", "is_recompute" +] _g_op_input_suffix = "@input" @@ -178,6 +180,7 @@ class OperatorDistributedAttribute: self._inputs_dist_attrs = {} self._outputs_dist_attrs = {} self._is_annotated = {} + self._is_recompute = False @property def process_mesh(self): @@ -214,6 +217,15 @@ class OperatorDistributedAttribute: if impl_idx is not None: self._impl_idx = impl_idx + @property + def is_recompute(self): + return self._is_recompute + + @is_recompute.setter + def is_recompute(self, is_recompute): + assert isinstance(is_recompute, bool) + self._is_recompute = is_recompute + @property def inputs_dist_attrs(self): return self._inputs_dist_attrs diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index b194bcc3de6b5bcbf8b26876e7c1d8c3b4b0b2b6..d3bf9e22db4387012d7d562da7ec4cc1b4a5b35c 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -166,6 +166,13 @@ class DistributedContext: else: return None + def get_tensor_dist_attr_for_program_with_id(self, tensor_id): + dist_tensor = self._dist_tensors_for_program.get(tensor_id, None) + if dist_tensor: + return dist_tensor.dist_attr + else: + return None + def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr): dist_tensor = DistributedTensor(serial_tensor, dist_attr) self.add_dist_tensor_for_program(dist_tensor) @@ -192,6 +199,13 @@ 
class DistributedContext: else: return None + def get_op_dist_attr_for_program_with_id(self, op_id): + dist_op = self._dist_ops_for_program.get(op_id, None) + if dist_op: + return dist_op.dist_attr + else: + return None + def set_op_dist_attr_for_program(self, serial_op, dist_attr): dist_op = DistributedOperator(serial_op, dist_attr) self.add_dist_op_for_program(dist_op) diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index ef595e2a00f2e1976add2e7b67ac644d6b09cf00..e2de876f01cade351523cbb75af6b1aea44dade1 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -99,6 +99,8 @@ class DistributedOperator: self._dist_attr.impl_type = "default" if self._dist_attr.impl_idx is None: self._dist_attr.impl_idx = -2 + if self._dist_attr.is_recompute is None: + self._dist_attr.is_recompute = False def _filter_dist_attr(self, dist_attr): if dist_attr is None: diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 8f1ba33f544fb35e2935dcf0d178f6c7e86cdd48..505e29282b87068bd822388270065f6d1ddbd12b 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -118,6 +118,8 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True): def is_parameter_related(varname, block): + if ".subprog_" in varname: + varname = varname[:varname.index(".subprog_")] if ".cast_fp" in varname: varname = varname[:varname.index(".cast_fp")] assert block.has_var(varname) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 1a3d57bf140ddd12966e187e02a65ac89b2741e9..a98ec89a5099a79301de6865b8b2830a091121f5 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -216,6 +216,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for varname in backward_op.desc.input(input_name): if "@GRAD" not in varname and is_parameter_related( varname, main_block): + # NOTE: When the amp and recompute passes are effective at the same time, + # if a parameter is cast and recomputed, the 'parameter@GRAD' cannot + # be found in the grad_op's output.
+ if "subprog_" in varname: + varname = varname[:varname.index(".subprog_")] + assert len( backward_op.desc.input(input_name) ) == 1, "parameter input to grad op should be length 1, but got [{}]".format( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 866fed1ae60677d526a49719bc9ebb9f0190a6ec..f019f499aa305b8cfde423fd64c1ad232e281656 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -283,7 +283,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): allreduce_op_dist_attr) # param initialization sync - if Weight_var.is_parameter: + if Weight_var.is_parameter and not op_dist_attr.is_recompute: assert Weight_var.name not in dist_op_context.already_init_sync_vars dist_op_context.already_init_sync_vars.add(Weight_var.name) param = startup_block.var(Weight_var.name) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index f4c31c3654c52e2c261b4be26936e758a83c4ac5..b0b185819c58ae911c167e86ea1408631ff0d475 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -680,7 +680,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr) # init param sync - if Weight_var.is_parameter: + if Weight_var.is_parameter and not op_dist_attr.is_recompute: _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id) @@ -968,7 +968,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): allreduce_op_dist_attr) # init param sync - if Weight_var.is_parameter: + if Weight_var.is_parameter and not op_dist_attr.is_recompute: _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id) @@ -1383,7 +1383,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr) # init param sync - if Weight_var.is_parameter: + if Weight_var.is_parameter and not op_dist_attr.is_recompute: _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id) @@ -1666,7 +1666,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): allreduce_op_dist_attr) # init param sync - if Weight_var.is_parameter: + if Weight_var.is_parameter and not op_dist_attr.is_recompute: _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py index 56782bec0856a79e3971037974110d51c84e719f..eccd2742db03feafafd529b1738f00bbd44a5dac 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py @@ -83,9 +83,9 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): assert 'Out' in kwargs, "output [{}] is not given".format('Out') assert 'LossScaling' in kwargs, "output [{}] is not given".format( 'LossScaling') - assert 'OutGoodSteps' in kwargs, "input [{}] is not given".format( + assert 'OutGoodSteps' in kwargs, "output [{}] is not given".format( 'OutGoodSteps') - assert 'OutBadSteps' in kwargs, "input [{}] is not given".format( + assert 'OutBadSteps' in kwargs, "output [{}] is 
not given".format( 'OutBadSteps') assert len(kwargs['FoundInfinite']) == 1, \ diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 294a966726d73abda548b8091db7daecf937b147..d6035d02953ac3f2b5ec6d6ffc720f87d30dcc41 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -97,8 +97,8 @@ class AutoParallelizer: if suffix in attr_name: op._remove_attr(attr_name) - def _apply_pre_optimization_passed(self, main_program, startup_program, - loss, params_grads): + def _apply_pre_optimization_passes(self, main_program, startup_program, + loss, params_grads, no_grad_set): # apply amp pass if self._dist_strategy.amp: config = copy.deepcopy(self._dist_strategy.amp_configs) @@ -111,11 +111,14 @@ class AutoParallelizer: # apply recompute pass if self._dist_strategy.recompute: - auto_parallel_recompute_pass = new_pass( - "auto_parallel_recompute_pass", - self._dist_strategy.recompute_configs) - auto_parallel_recompute_pass.apply(main_program, startup_program, - self._pass_context) + config = copy.deepcopy(self._dist_strategy.recompute_configs) + config["dist_context"] = self._dist_context + config["no_grad_set"] = copy.deepcopy(no_grad_set) + config["loss"] = loss + auto_parallel_recompute_pass = new_pass("auto_parallel_recompute", + config) + auto_parallel_recompute_pass.apply( + [main_program], [startup_program], self._pass_context) def _generate_backward(self, main_program, startup_program, loss, parameter_list, no_grad_set, callbacks): @@ -144,7 +147,7 @@ class AutoParallelizer: return optimize_ops - def _apply_post_optimization_passed(self, main_program, startup_program, + def _apply_post_optimization_passes(self, main_program, startup_program, rank, params_grads): if self._dist_strategy.sharding: @@ -188,9 +191,9 @@ class AutoParallelizer: self._parameter_list, self._no_grad_set, self._callbacks) # serial forward pass - self._apply_pre_optimization_passed(completed_main_program, + self._apply_pre_optimization_passes(completed_main_program, serial_startup_program, serial_loss, - params_grads) + params_grads, self._no_grad_set) # Logical partition partitioner = Partitioner(self._dist_context, rank) dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( @@ -207,7 +210,7 @@ class AutoParallelizer: reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context) - self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog, + self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog, rank, dist_params_grads) g_process_group_map = None if not relaunch_phase: diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index 76a9faa1c8398ebcb85e26ba88889fe9360c46fc..182f6e8b6604a36149c36be9ed6e222bf9673b1c 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -25,7 +25,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext, Di from .dist_attribute import OperatorDistributedAttribute from .process_group import new_process_group from .utils import set_dist_op_desc_original_id -from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op +from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op from .operators.common import BACKWARD_ONLY_DIST_OPS __varname_not_in_block__ = 
["lod_tensor_blocking_queue_0"] @@ -200,7 +200,8 @@ class Partitioner(object): serial_output_varname] = new_varname # partition op - if is_forward_op(op): + op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op) + if is_forward_op(op) or op_dist_attr.is_recompute: kinputs, koutputs = dist_op_context.prepare_context(op) dist_op_forward_impl = _get_dist_op_forward_implement( op, self._dist_context) @@ -380,9 +381,9 @@ def _get_dist_op_backward_implement(backward_op, dist_context, # NOTE trick for dist ops that only have backward implement if backward_op.type in BACKWARD_ONLY_DIST_OPS: op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op) - assert op_dist_attr.impl_idx >= 0 - return get_distributed_operator_impl_container( - backward_op.type).get_impl(op_dist_attr.impl_idx) + dist_op = get_distributed_operator_impl_container(backward_op.type) + if dist_op and op_dist_attr.impl_idx >= 0: + return dist_op.get_impl(op_dist_attr.impl_idx) dist_op = get_distributed_operator_impl_container("default") return dist_op.get_impl(0) diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index b0249356eddb14da004bb817ff43c17185da560a..6e6d2a672fd18631c4f0ac7073eaada488b37967 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -26,6 +26,9 @@ from .dist_context import DistributedContext from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from .process_group import new_process_group, ProcessGroup, _g_process_group_map +# NOTE: If op in _g_special_ops, it will not be resharded. +_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling'] + class AllGatherOpDesc: """ @@ -966,6 +969,17 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id, while idx < len(block.ops): pre_op_count = len(block.ops) op = block.ops[idx] + + def _is_special_op(op): + global _g_special_ops + if op.type in _g_special_ops: + return True + return False + + if _is_special_op(op): + idx += 1 + continue + dist_op = dist_context.get_dist_op_for_program(op) if dist_op is not None: idx_offset = 0 diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 2316f207ffe8e52a8c9c5f57db41e64c3066af66..1867731974f117ef2530531627b116b83dbdbff3 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1005,8 +1005,8 @@ def set_grad_var_shape(program, dist_context): assert op_dist_attr is not None for var_name in op.output_arg_names: - - assert "@GRAD" in var_name + if "@GRAD" not in var_name: + continue forward_var_name = var_name[:var_name.find("@GRAD")] if op.type in [ "c_allreduce_sum", "c_identity", "scale", "cast" @@ -1076,11 +1076,6 @@ def is_backward_op(op): int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward) -def is_recompute_op(op): - return OP_ROLE_KEY in op.attr_names and \ - int(op.all_attrs()[OP_ROLE_KEY]) == 9 - - def is_loss_op(op): return OP_ROLE_KEY in op.attr_names and \ int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss)) diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index 06f2efe08a4891f834b8bb767b385b33677634f2..2519c7d1803c29de55d0f27b2e54b24aa94d9070 100644 --- a/python/paddle/distributed/passes/__init__.py +++ 
b/python/paddle/distributed/passes/__init__.py @@ -17,6 +17,7 @@ from .fuse_all_reduce import * from .auto_parallel_gradient_merge import * from .auto_parallel_sharding import * from .auto_parallel_amp import * +from .auto_parallel_recompute import * from .cpp_pass import * __all__ = [ diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py new file mode 100644 index 0000000000000000000000000000000000000000..4039f3ed746772a49fd022a99d5ff281aa177a7e --- /dev/null +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -0,0 +1,402 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging + +from .pass_base import PassBase, register_pass +from paddle.fluid import core, unique_name +from paddle.fluid import framework as framework +from paddle.fluid.framework import Variable, Operator +from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name +from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_ +from paddle.distributed.auto_parallel.process_mesh import ProcessMesh +from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute +from paddle.distributed.auto_parallel.utils import get_loss_op, set_var_dist_attr, set_dist_op_desc_original_id +from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping + + +class RecomputeState(ProgramStats): + def __init__(self, block, ops): + super(RecomputeState, self).__init__(block=block, ops=ops) + self._block = block + self._ops = ops + self.var_op_deps = {} + + def build_stats(self): + for i, op in enumerate(self._ops): + for name in op.desc.input_arg_names(): + if name in self.var_op_deps: + self.var_op_deps[name]["var_as_input_ops"].extend([i]) + else: + self.var_op_deps[name] = {} + self.var_op_deps[name]["var_as_input_ops"] = [i] + self.var_op_deps[name]["var_as_output_ops"] = [] + + for name in op.desc.output_arg_names(): + if name in self.var_op_deps: + self.var_op_deps[name]["var_as_output_ops"].extend([i]) + else: + self.var_op_deps[name] = {} + self.var_op_deps[name]["var_as_input_ops"] = [] + self.var_op_deps[name]["var_as_output_ops"] = [i] + + def get_recompute_segments(self, checkpoints): + """ get recompute segments from checkpoints """ + segments = [] + start_idx = -1 + pre_segment_end_idx = -1 + while start_idx + 1 < len(checkpoints): + if start_idx == -1: + ckpt_name = checkpoints[start_idx + 1] + if ckpt_name not in self.var_op_deps: + start_idx += 1 + continue + op_idx_list = self.var_op_deps[ckpt_name]["var_as_output_ops"] + if op_idx_list: + segments.append([0, max(op_idx_list) + 1]) + else: + flag, min_idx, max_idx = self.is_subgraph( + [checkpoints[start_idx]], [checkpoints[start_idx + 1]]) + if flag: + min_idx = self._update_segment_start(min_idx, + pre_segment_end_idx) + segments.append([min_idx, max_idx + 1]) + else: + logging.info("Could not 
recompute op range [{}] - [{}] ". + format(min_idx, max_idx + 1)) + start_idx += 1 + + for i, (idx1, idx2) in enumerate(segments): + logging.info("recompute segment[{}]".format(i)) + logging.info("segment start op: [{}]: [{}] [{}]".format(self._ops[ + idx1].desc.type(), self._ops[idx1].desc.input_arg_names( + ), self._ops[idx1].desc.output_arg_names())) + logging.info("segment end op: [{}]: [{}] [{}]".format(self._ops[ + idx2 - 1].desc.type(), self._ops[idx2 - 1].desc.input_arg_names( + ), self._ops[idx2 - 1].desc.output_arg_names())) + + return segments + + def modify_forward_desc_for_recompute(self, dist_context): + """ + If the program's forward part has a 'dropout' op, this function will insert + a seed op before it to guarantee that the dropout op and its recomputed copy have the same outputs. + """ + op_types = [op.desc.type() for op in self._ops] + if "dropout" not in op_types: + return + + op_idx = 0 + while op_idx < len(self._ops): + cur_op = self._ops[op_idx] + if "grad" in cur_op.type: + break + if cur_op.type != "dropout": + op_idx += 1 + continue + if cur_op.input("Seed") is not None and len(cur_op.input("Seed")): + op_idx += 1 + continue + + cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op) + # insert a seed op to guarantee that the dropout op and its recomputed copy have the same outputs + op_unique_name = unique_name.generate("seed") + var_unique_name = unique_name.generate_with_ignorable_key(".".join( + [op_unique_name, 'tmp'])) + seed_var = self._block.create_var( + name=var_unique_name, + dtype='int32', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + + # set new seed_var's dist_attr + ref_dims_mapping = [-1] + ref_process_mesh = cur_op_dist_attr.process_mesh + seed_var_dist_attr = set_var_dist_attr( + dist_context, seed_var, ref_dims_mapping, ref_process_mesh) + + seed = 0 if cur_op.attr("fix_seed") is False else int( + cur_op.attr("seed")) + seed_op = self._block._insert_op_without_sync( + index=cur_op.idx, + type="seed", + inputs={}, + outputs={"Out": seed_var}, + attrs={"seed": seed, + "force_cpu": True}) + # set new seed op's dist_attr + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + seed_op, ref_process_mesh, ref_dims_mapping, dist_context) + + # modify dropout op's desc + self._ops.insert(op_idx, seed_op) + cur_op.desc.set_input("Seed", [var_unique_name]) + cur_op.desc.remove_attr("fix_seed") + cur_op.desc.remove_attr("seed") + cur_op_dist_attr.set_input_dist_attr(seed_var.name, + seed_var_dist_attr) + self._block._sync_with_cpp() + op_idx += 2 + + +def _find_op_index(block, cur_op): + for idx in range(block.desc.op_size()): + if cur_op.desc == block.desc.op(idx): + return idx + return -1 + + +def _get_stop_gradients(program, no_grad_set): + """ get no grad var """ + if no_grad_set is None: + no_grad_set = set() + else: + no_grad_set = _get_no_grad_set_name(no_grad_set) + + no_grad_set_name = set() + for var in program.list_vars(): + assert isinstance(var, Variable) + if "@GRAD" in var.name: + break + if var.stop_gradient: + no_grad_set_name.add(_append_grad_suffix_(var.name)) + no_grad_set_name.update(list(map(_append_grad_suffix_, no_grad_set))) + return no_grad_set_name + + +def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars, + dist_context): + """ + Get the recomputed ops which will be inserted into the backward part + """ + if len(descs) == 0: + return [] + result_descs = [] + op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() + backward = core.op_proto_and_checker_maker.OpRole.Backward + for desc in descs: + if
isinstance(desc, framework.Operator): + desc = desc.desc + if isinstance(desc, tuple): + desc = desc[0] + is_needed = False + for name in desc.output_arg_names(): + if main_block.has_var(name) and main_block.var(name).persistable: + continue + if name not in in_memory_vars: + is_needed = True + if is_needed: + new_op_desc = block.desc.append_op() + new_op_desc.copy_from(desc) + set_dist_op_desc_original_id(new_op_desc, desc, dist_context) + new_op_desc._set_attr(op_role_attr_name, backward) + result_descs.append(new_op_desc) + return result_descs + + +@register_pass("auto_parallel_recompute") +class RecomputePass(PassBase): + def __init__(self): + super(RecomputePass, self).__init__() + self.set_attr("checkpoints", None) + self.set_attr("loss", None) + self.set_attr("dist_context", None) + self.set_attr("no_grad_set", None) + + def _check_self(self): + if self.get_attr("dist_context") is None: + return False + if self.get_attr("loss") is None: + return False + if self.get_attr("checkpoints") is None: + return False + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_programs, startup_programs, context): + checkpoints = self.get_attr("checkpoints") + loss = self.get_attr("loss") + no_grad_set = self.get_attr("no_grad_set") + self._dist_context = self.get_attr("dist_context") + + main_block = main_programs.global_block() + no_grad_set_name = _get_stop_gradients(main_programs, no_grad_set) + # get op_path which is related to loss + op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name) + + # step 1: build recompute state + rc_state = RecomputeState(main_block, op_path) + rc_state.modify_forward_desc_for_recompute(self._dist_context) + rc_state.build_stats() + checkpoints = rc_state.sort_checkpoints(checkpoints) + segments = rc_state.get_recompute_segments(checkpoints) + if segments == []: + return + + # step 2: get vars_should_be_hold + vars_should_be_hold = [] + for segment in segments: + vars_should_be_hold.extend( + rc_state.get_out_of_subgraph_vars(segment[0], segment[1])) + cross_vars = set(vars_should_be_hold) - set(checkpoints) + logging.info("found [{}] vars which cross recompute segment: [{}]," + "better checkpoints might be set to reduce those vars". 
+ format(len(cross_vars), cross_vars)) + vars_should_be_hold.extend(rc_state.get_reserved_vars()) + vars_should_be_hold.extend(rc_state.get_input_nodes()) + vars_should_be_hold = list(set(vars_should_be_hold)) + vars_in_memory = vars_should_be_hold + checkpoints + + # step 3: get recomputed fwd ops desc + var_name_dict = {} + ckpt_ops_dict = {} + buffer_block = main_block.program._create_block() + for i, segment in enumerate(segments[::-1]): + fwd_ops = op_path[segment[0]:segment[1]] + var_suffix = ".subprog_%d" % i + for op in fwd_ops: + input_and_output_names = [] + input_and_output_names.extend(op.desc.input_arg_names()) + input_and_output_names.extend(op.desc.output_arg_names()) + cur_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( + op) + assert cur_op_dist_attr is not None + for name in input_and_output_names: + if main_block.var(name).persistable or name in checkpoints: + continue + if name in vars_should_be_hold: + continue + if name not in var_name_dict: + ref_process_mesh = cur_op_dist_attr.process_mesh + if name in op.desc.input_arg_names(): + ref_dims_mapping = cur_op_dist_attr.get_input_dims_mapping( + name) + else: + ref_dims_mapping = cur_op_dist_attr.get_output_dims_mapping( + name) + # record recomputed var's old_name and new_name (old_name.subprog_XXX) + # create new var with new name + var_name_dict[name] = name + var_suffix + ref_var = main_block.var(name) + rc_var = main_block.create_var( + name=var_name_dict[name], + shape=ref_var.shape, + dtype=ref_var.dtype, + type=ref_var.type, + persistable=ref_var.persistable, + stop_gradient=ref_var.stop_gradient) + # set new recomputed var's dist attr + set_var_dist_attr(self._dist_context, rc_var, + ref_dims_mapping, ref_process_mesh) + # get recomputed segment's descs + segment_descs = _add_needed_descs_to_block( + fwd_ops, buffer_block, main_block, vars_in_memory, + self._dist_context) + # rename recomputed ops' input and output var name + for key in var_name_dict: + _rename_arg_(segment_descs, key, var_name_dict[key]) + + # NOTE: one forward op could correspond to multiple xxx_grad ops. + # When traversing all grad_ops in reverse, a flag is needed to indicate + # whether the ckpt and its segment_descs can be used. + ckpt_op = op_path[segment[1] - 1] + ckpt_ops_dict[ckpt_op.desc.id()] = [True, segment_descs] + + # step 4: insert recomputed fwd ops + ops = main_block.ops + loss_op = get_loss_op(main_block) + loss_op_idx = _find_op_index(main_block, loss_op) + dist_op_context = self._dist_context.dist_op_context + assert loss_op_idx != -1 + # Traverse all grad_ops in reverse; if the fwd op corresponding to a grad op is a checkpoint, + # its segment's ops should be inserted.
+ for i in range(len(ops) - 1, loss_op_idx, -1): + grad_op = ops[i] + # remove some attrs of dropout_grad op's desc + if grad_op.type == "dropout_grad": + grad_op.desc.remove_attr("fix_seed") + grad_op.desc.remove_attr("seed") + main_block._sync_with_cpp() + + # rename grad op's var_name which is not in 'vars_in_memory' + for key in var_name_dict: + self.reset_op_dist_attr(grad_op, var_name_dict) + _rename_arg_([grad_op.desc], key, var_name_dict[key]) + + # insert recomputed ops + if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id: + fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id( + )] + if fwd_op_id in ckpt_ops_dict and ckpt_ops_dict[fwd_op_id][0]: + idx = grad_op.idx + while idx - 1 >= 0 and ops[idx - 1].type == "sum": + idx -= 1 + segment_descs = ckpt_ops_dict[fwd_op_id][1] + for _, op_desc in reversed(list(enumerate(segment_descs))): + rc_desc = main_block.desc._insert_op(idx) + rc_desc.copy_from(op_desc) + rc_op = Operator(main_block, rc_desc) + main_block.ops.insert(idx, rc_op) + # set recomputed ops' dist attr + fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id( + rc_desc.original_id()) + assert fwd_op_dist_attr is not None + self.set_op_dist_attr(rc_op, fwd_op_dist_attr, + var_name_dict) + + ckpt_ops_dict[fwd_op_id][0] = False + main_block._sync_with_cpp() + + main_programs._sync_with_cpp() + + def reset_op_dist_attr(self, op, var_name_dict): + op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op) + assert op_dist_attr is not None + for input in op.desc.input_arg_names(): + if input in var_name_dict.keys(): + in_dist_attr = op_dist_attr.get_input_dist_attr(input) + op_dist_attr.set_input_dist_attr(var_name_dict[input], + in_dist_attr) + for output in op.desc.output_arg_names(): + if output in var_name_dict.keys(): + out_dist_attr = op_dist_attr.get_output_dist_attr(output) + op_dist_attr.set_output_dist_attr(var_name_dict[output], + out_dist_attr) + + def set_op_dist_attr(self, op, old_dist_attr, var_name_dict): + new_dist_attr = OperatorDistributedAttribute() + new_dist_attr.is_recompute = True + new_dist_attr.impl_idx = old_dist_attr.impl_idx + new_dist_attr.process_mesh = old_dist_attr.process_mesh + for input in old_dist_attr.inputs_dist_attrs.keys(): + if input in var_name_dict.keys(): + in_dist_attr = old_dist_attr.inputs_dist_attrs[input] + new_dist_attr.set_input_dist_attr(var_name_dict[input], + in_dist_attr) + else: + in_dist_attr = old_dist_attr.inputs_dist_attrs[input] + new_dist_attr.set_input_dist_attr(input, in_dist_attr) + for output in old_dist_attr.outputs_dist_attrs.keys(): + if output in var_name_dict.keys(): + out_dist_attr = old_dist_attr.outputs_dist_attrs[output] + new_dist_attr.set_output_dist_attr(var_name_dict[output], + out_dist_attr) + else: + out_dist_attr = old_dist_attr.outputs_dist_attrs[output] + new_dist_attr.set_output_dist_attr(output, out_dist_attr) + self._dist_context.set_op_dist_attr_for_program(op, new_dist_attr) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py index b9ba724f2c556ffb88f7f619ad9568dfb5f5c22b..b77f9e249c265c4a49b1c1f7df040c1072f9312e 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py @@ -894,7 +894,6 @@ class GPTModel(nn.Layer): "dims_mapping": [0] + [-1 for i in range(len(input_ids.shape) - 1)] }) - attention_mask.stop_gradient = True encoder_outputs = self.decoder( 
embedding_output, memory=None, diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py index 42bdf678242206f62592dd1359b8b15f7c59a1c8..e024ef1d5d1900efc900808dfd5981db535cc930 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py @@ -110,14 +110,8 @@ class AutoPallelPassTestBase(DistPassTestBase): elif strategy == "mp": modeling._global_parallel_strategy = "mp" modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) - elif strategy == "pp": - modeling._global_parallel_strategy = "pp" - modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) - modeling.PP_MESH_LIST = [ - auto.ProcessMesh(mesh=[0]), auto.ProcessMesh(mesh=[1]) - ] else: - raise ValueError("'get_gpt_model' only support dp, mp and pp.") + raise ValueError("'get_gpt_model' only support dp and mp.") tokens = paddle.static.data( name="tokens", shape=[batch_size, sequence_len], dtype='int64') diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..1875c8b1da983f965b3477d78b4a28768ef91efe --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_recompute_pass.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import random +import numpy as np + +import unittest +import paddle +import paddle.nn as nn +import paddle.distributed.fleet as fleet +import paddle.distributed.auto_parallel as auto +from paddle.distributed.passes import new_pass, PassManager +from auto_parallel_pass_test_base import AutoPallelPassTestBase + + +class TestRecomputePass(AutoPallelPassTestBase): + def init(self): + if paddle.is_compiled_with_cuda(): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + self.rtol = 1e-6 + self.atol = 1e-8 + + rank = paddle.distributed.get_rank() + paddle.seed(rank + 2021) + random.seed(rank + 2021) + np.random.seed(rank + 2021) + + def apply_passes(self): + dist_strategy = fleet.DistributedStrategy() + dist_strategy.recompute = True + dist_strategy.recompute_configs = {"checkpoints": ["tmp3", "tmp6"]} + dist_strategy.semi_auto = True + fleet.init(is_collective=True, strategy=dist_strategy) + + def test_bs_8(self): + self.check_main( + gpus=[0, 1], batch_size=8, sequence_len=512, vocab_size=1000) + + def get_model(self, place, batch_size, sequence_len, vocab_size): + return self.get_gpt_model("mp", place, batch_size, sequence_len, + vocab_size) + + +class TestRecomputePassDP(TestRecomputePass): + def get_model(self, place, batch_size, sequence_len, vocab_size): + return self.get_gpt_model("dp", place, batch_size, sequence_len, + vocab_size) + + +if __name__ == "__main__": + unittest.main()
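A minimal sketch of how the new recompute pass is driven, mirroring AutoParallelizer._apply_pre_optimization_passes above; dist_strategy, main_program, startup_program, loss, no_grad_set, dist_context, and pass_context are assumed to be supplied by the surrounding auto-parallel pipeline:

import copy
from paddle.distributed.passes import new_pass

# recompute_configs comes from fleet.DistributedStrategy, e.g. {"checkpoints": ["tmp3", "tmp6"]}
config = copy.deepcopy(dist_strategy.recompute_configs)
config["dist_context"] = dist_context          # DistributedContext carrying the dist attrs
config["no_grad_set"] = copy.deepcopy(no_grad_set)
config["loss"] = loss                          # loss variable used to find the related op path

auto_parallel_recompute_pass = new_pass("auto_parallel_recompute", config)
# like the other auto-parallel passes, apply() takes lists of programs plus a pass context
auto_parallel_recompute_pass.apply([main_program], [startup_program], pass_context)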