From 5620214e5aba507dbe4084597be3e46eb2cb8705 Mon Sep 17 00:00:00 2001
From: Yulong Ao
Date: Thu, 30 Dec 2021 16:42:09 +0800
Subject: [PATCH] [Auto parallel] Make sure the id semantics of every var and op unique (#38132)

* [Auto parallel] Make the id of var and op unique

* [Auto Parallel] Rename back dist_context to distop_context
---
 paddle/fluid/framework/ir/node.h              |  14 +-
 paddle/fluid/framework/op_desc.cc             |  10 +-
 paddle/fluid/framework/op_desc.h              |  27 ++--
 paddle/fluid/framework/var_desc.h             |  32 +++--
 paddle/fluid/pybind/ir.cc                     |   1 +
 paddle/fluid/pybind/protobuf.cc               |   4 +
 .../distributed/auto_parallel/completion.py   |   8 +-
 .../distributed/auto_parallel/dist_context.py | 126 +++++++++++------
 .../auto_parallel/operators/dist_default.py   |   4 +
 .../auto_parallel/operators/dist_matmul.py    |  12 +-
 .../auto_parallel/operators/dist_reshape.py   |   3 +
 .../distributed/auto_parallel/parallelizer.py |   1 +
 .../distributed/auto_parallel/partitioner.py  |   6 +-
 .../paddle/distributed/auto_parallel/utils.py |  16 +++
 python/paddle/fluid/backward.py               |  15 ++-
 15 files changed, 185 insertions(+), 94 deletions(-)

diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index f4cca78b6d..7e61d6ae42 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -122,6 +122,10 @@ class Node {
   // Please don't use this API!
   int id() const { return id_; }
 
+  // Only use this for auto parallel.
+  // A node does not have original desc if the return is zero.
+  uint64_t OriginalDescId() const { return original_desc_id_; }
+
   bool IsOp() const { return type_ == Type::kOperation; }
   bool IsVar() const { return type_ == Type::kVariable; }
   bool IsCtrlVar() const {
@@ -239,6 +243,10 @@ class Node {
   int desc_order_;
   int block_id_{-1};
 
+  // Store the original id of var desc or op desc.
+  // Only use this for auto parallel.
+  uint64_t original_desc_id_{0};
+
  private:
   // ID can only set by a Graph.
   void SetId(int id) { id_ = id; }
@@ -267,14 +275,16 @@ class Node {
         op_desc_(nullptr),
         type_(Type::kVariable),
         desc_order_(NO_DESC_ORDER),
-        block_id_(block_id) {}
+        block_id_(block_id),
+        original_desc_id_(var_desc->OriginalId()) {}
 
   explicit Node(OpDesc* op_desc)
       : name_(op_desc->Type()),
         var_desc_(nullptr),
         op_desc_(new OpDesc(*op_desc, op_desc->Block())),
         type_(Type::kOperation),
-        desc_order_(NO_DESC_ORDER) {}
+        desc_order_(NO_DESC_ORDER),
+        original_desc_id_(op_desc->OriginalId()) {}
 
   Node() = delete;
 
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 670a269391..4254ec236d 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -352,15 +352,9 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
   inputs_ = op_desc.inputs_;
   outputs_ = op_desc.outputs_;
   attrs_ = op_desc.attrs_;
+  // The record of original_id_ is only for auto parallel.
+  original_id_ = op_desc.original_id_;
   need_update_ = true;
-  // When creating graph from program, the creation of op node will create a new
-  // OpDesc instead of
-  // referring to the original one. To find the original OpDesc of the op node,
-  // the id have to be
-  // copied to the new OpDesc. The var node has the same situation, but the
-  // default copy constructor
-  // can copy the id automatically.
-  id_ = op_desc.id_;
 }
 
 OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h
index 9470fd9b69..82e15d40be 100644
--- a/paddle/fluid/framework/op_desc.h
+++ b/paddle/fluid/framework/op_desc.h
@@ -154,17 +154,10 @@ class OpDesc {
 
   const BlockDesc *Block() const { return this->block_; }
 
-  // This thread-safe implementation seems to be redudent since the neural
-  // networks
-  // are usually constructed in a single thread
-  static uint64_t GenerateId() {
-    static std::atomic<uint64_t> id{0};
-    return ++id;
-  }
-
-  // Note: the identity only used as a key for referring to its
-  // distributed attribute now.
+  // The Id() and OriginalId() are only used for auto parallel.
   uint64_t Id() const { return id_; }
+  uint64_t OriginalId() const { return original_id_; }
+  void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
 
  private:
  template 
    return ret_val;
  }
 
+  // This thread-safe implementation seems to be redundant since the neural
+  // networks are usually constructed in a single thread
+  static uint64_t GenerateId() {
+    static std::atomic<uint64_t> uid{0};
+    // Must start from one
+    return ++uid;
+  }
+
   proto::OpDesc desc_;
   BlockDesc *block_{nullptr};  // not_own
   // input arg name => input variable names
@@ -189,7 +190,13 @@ class OpDesc {
   // local changes should be synchronized, need_update_ should be set to true.
   bool need_update_{false};
 
+  // Note: the id_ is unique (only for auto parallel).
   uint64_t id_ = GenerateId();
+  // Note: the original_id_ is used for referring to the original OpDesc
+  // that the current OpDesc is built from (only for auto parallel).
+  // The default original_id_ is the same as the id_, which means the
+  // current OpDesc is not built from another one.
+  uint64_t original_id_ = id_;
 };
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/var_desc.h b/paddle/fluid/framework/var_desc.h
index afe420dd25..a20ef58f9c 100644
--- a/paddle/fluid/framework/var_desc.h
+++ b/paddle/fluid/framework/var_desc.h
@@ -69,6 +69,12 @@ class VarDesc {
 
   explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
 
+  // Explicitly implement the copy constructor for auto parallel
+  VarDesc(const VarDesc &other)
+      : desc_(other.desc_),
+        attrs_(other.attrs_),
+        original_id_(other.original_id_) {}
+
   proto::VarDesc *Proto() { return &desc_; }
 
   const proto::VarDesc *Proto() const { return &desc_; }
@@ -153,16 +159,10 @@ class VarDesc {
 
   Attribute GetAttr(const std::string &name) const;
 
-  // This thread-safe implementation seems to be redudent since the neural
-  // networks are usually constructed in a single thread.
-  static uint64_t GenerateId() {
-    static std::atomic<uint64_t> uid{0};
-    return ++uid;
-  }
-
-  // Note: the identity only used as a key for referring to its
-  // distributed attribute now.
+  // The Id() and OriginalId() are only used for auto parallel.
   uint64_t Id() const { return id_; }
+  uint64_t OriginalId() const { return original_id_; }
+  void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
 
  private:
   const proto::VarType::TensorDesc &tensor_desc() const;
   proto::VarType::TensorDesc *mutable_tensor_desc();
   std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
 
+  // This thread-safe implementation seems to be redundant since the neural
+  // networks are usually constructed in a single thread.
+  static uint64_t GenerateId() {
+    static std::atomic<uint64_t> uid{0};
+    return ++uid;
+  }
+
   proto::VarDesc desc_;
   AttributeMap attrs_;
+
+  // Note: the id_ is unique for all VarDesc (only for auto parallel).
   uint64_t id_ = GenerateId();
+  // Note: the original_id_ is used for referring to the original VarDesc
+  // that the current VarDesc is built from (only for auto parallel).
+  // The default original_id_ is the same as the id_, which means the
+  // current VarDesc is not built from another one.
+  uint64_t original_id_ = id_;
 };
 
 bool operator==(const VarDesc &left, const VarDesc &right);
diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc
index f2fb4671df..bb45c1c406 100644
--- a/paddle/fluid/pybind/ir.cc
+++ b/paddle/fluid/pybind/ir.cc
@@ -143,6 +143,7 @@ void BindNode(py::module *m) {
       .def("var", &Node::Var, return_value_policy::reference)
       .def("op", &Node::Op, return_value_policy::reference)
       .def("id", &Node::id)
+      .def("original_desc_id", &Node::OriginalDescId)
      .def("is_op", &Node::IsOp)
      .def("is_var", &Node::IsVar)
      .def("is_ctrl_var", &Node::IsCtrlVar)
diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc
index 44a8a54c8c..66bf8c9517 100644
--- a/paddle/fluid/pybind/protobuf.cc
+++ b/paddle/fluid/pybind/protobuf.cc
@@ -208,6 +208,8 @@ void BindVarDsec(pybind11::module *m) {
       .def("_set_attr", &pd::VarDesc::SetAttr)
       .def("remove_attr", &pd::VarDesc::RemoveAttr)
       .def("id", &pd::VarDesc::Id)
+      .def("original_id", &pd::VarDesc::OriginalId)
+      .def("set_original_id", &pd::VarDesc::SetOriginalId)
       .def("attr", &pd::VarDesc::GetAttr);
 
   pybind11::enum_ vartype(var_desc, "VarType", "");
@@ -305,6 +307,8 @@ void BindOpDesc(pybind11::module *m) {
       .def("block", [](pd::OpDesc &self) { return self.Block(); },
            pybind11::return_value_policy::reference)
       .def("id", &pd::OpDesc::Id)
+      .def("original_id", &pd::OpDesc::OriginalId)
+      .def("set_original_id", &pd::OpDesc::SetOriginalId)
       .def("inputs", &pd::OpDesc::Inputs)
       .def("outputs", &pd::OpDesc::Outputs);
 }
diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index 745a018e8c..b038581192 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -698,13 +698,13 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
             continue
 
         # complete the annotation of grad op (xxx_grad op or sum op)
-        # xxx_grad op will have a corresponding forward op in gradopidx2opidx
+        # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
         grad_op = ops[idx]
-        if grad_op.desc.id() in dist_op_context.gradopidx2opidx:
+        if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
             # TODO support the case where one forward op corresponding to multiple xxx_grad op
             forward_op = _get_op_by_id(
                 ops[:first_backward_op_idx],
-                dist_op_context.gradopidx2opidx[grad_op.desc.id()])
+                dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
             assert forward_op is not None
 
             # op dist attr
@@ -769,7 +769,7 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
 
             dist_context.set_op_dist_attr_for_program(grad_op, grad_op_dist_attr)
 
-        # only sum op for merge mutiple version grad has no a corresponding mapping in gradopidx2opidx
+        # only the sum op that merges multiple versions of a grad has no corresponding mapping in grad_op_id_to_op_id
         else:
             assert grad_op.type == "sum", "got unexpect op [{}]".format(
                 str(grad_op.type))
diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py
index 3ec63fa116..fb8fd8948f 100644
--- a/python/paddle/distributed/auto_parallel/dist_context.py
+++ b/python/paddle/distributed/auto_parallel/dist_context.py
@@ -46,14 +46,19 @@ class DistributedContext:
     """
 
     def __init__(self, program=None):
+        # Program related data members
         self._serial_program = program
-        self._serial_graph = None
         self._is_initialized_for_program = False
-        self._is_initialized_for_graph = False
         self._dist_tensors_for_program = {}
         self._dist_ops_for_program = {}
+        # Graph related data members
+        self._is_initialized_for_graph = False
+        self._serial_graph = None
         self._dist_tensors_for_graph = {}
         self._dist_ops_for_graph = {}
+        self._node_id_to_tensor_id = {}
+        self._node_id_to_op_id = {}
+        # Other data members
         self._dist_op_context = DistributedOperatorContext()
         self._process_meshes = []
 
@@ -97,19 +102,38 @@ class DistributedContext:
 
     def get_dist_tensor_for_program(self, serial_tensor):
         serial_tensor_id = serial_tensor.desc.id()
-        return self._dist_tensors_for_program.get(serial_tensor_id, None)
+        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
+        if dist_tensor:
+            return dist_tensor
+        else:
+            serial_tensor_id = serial_tensor.desc.original_id()
+            dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id,
+                                                             None)
+            if dist_tensor:
+                return dist_tensor
+            else:
+                return None
 
     def get_dist_tensor_for_graph(self, serial_tensor_node):
         serial_tensor_node_id = serial_tensor_node.id()
         return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)
 
-    def get_dist_op_for_program(self, serial_tensor):
-        serial_tensor_id = serial_tensor.desc.id()
-        return self._dist_ops_for_program.get(serial_tensor_id, None)
+    def get_dist_op_for_program(self, serial_op):
+        serial_op_id = serial_op.desc.id()
+        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
+        if dist_op:
+            return dist_op
+        else:
+            serial_op_id = serial_op.desc.original_id()
+            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
+            if dist_op:
+                return dist_op
+            else:
+                return None
 
-    def get_dist_op_for_graph(self, serial_tensor_node):
-        serial_tensor_node_id = serial_tensor_node.id()
-        return self._dist_ops_for_graph.get(serial_tensor_node_id, None)
+    def get_dist_op_for_graph(self, serial_op_node):
+        serial_op_node_id = serial_op_node.id()
+        return self._dist_ops_for_graph.get(serial_op_node_id, None)
 
     def get_tensor_dist_attr_for_program(self, serial_tensor):
         serial_tensor_id = serial_tensor.desc.id()
@@ -117,7 +141,13 @@ class DistributedContext:
         if dist_tensor:
             return dist_tensor.dist_attr
         else:
-            return None
+            serial_tensor_id = serial_tensor.desc.original_id()
+            dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id,
+                                                             None)
+            if dist_tensor:
+                return dist_tensor.dist_attr
+            else:
+                return None
 
     def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
         dist_tensor = DistributedTensor(serial_tensor, dist_attr)
@@ -132,25 +162,18 @@ class DistributedContext:
         else:
             return None
 
-    def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
-        assert serial_tensor_node.is_var() and \
-            serial_tensor_node.var() is not None
-        serial_tensor_id = serial_tensor_node.var().id()
-        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
-        assert dist_tensor is not None, \
-            "The distributed tensor of the program has not been added to this context."
-        serial_tensor_node_id = serial_tensor_node.id()
-        new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
-                                            dist_attr)
-        self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
-
     def get_op_dist_attr_for_program(self, serial_op):
         serial_op_id = serial_op.desc.id()
         dist_op = self._dist_ops_for_program.get(serial_op_id, None)
         if dist_op:
             return dist_op.dist_attr
         else:
-            return None
+            serial_op_id = serial_op.desc.original_id()
+            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
+            if dist_op:
+                return dist_op.dist_attr
+            else:
+                return None
 
     def set_op_dist_attr_for_program(self, serial_op, dist_attr):
         dist_op = DistributedOperator(serial_op, dist_attr)
@@ -164,17 +187,6 @@ class DistributedContext:
         else:
             return None
 
-    def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
-        assert serial_op_node.is_op() and \
-            serial_op_node.op() is not None
-        serial_op_id = serial_op_node.op().id()
-        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
-        assert dist_op is not None, \
-            "The distributed operator of the program has not been added to this context."
-        serial_op_node_id = serial_op_node.id()
-        new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
-        self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
-
     def init_dist_attr_for_program(self):
         assert self._serial_program, \
             "Please set the program of this context before initializing its distribute attributes."
@@ -216,20 +228,36 @@ class DistributedContext:
         all_nodes = self._serial_graph.all_nodes()
         for node in all_nodes:
             if node.is_var() and node.var() is not None:
-                tensor_desc = node.var()
-                tensor_id = tensor_desc.id()
-                dist_tensor = self._dist_tensors_for_program.get(tensor_id,
-                                                                 None)
+                dist_tensor = None
+                tensor_id = node.node.original_desc_id()
+                for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items(
+                ):
+                    if tensor_id == cur_tensor_id \
+                        or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
+                        dist_tensor = cur_dist_tensor
+                        self._node_id_to_tensor_id[node.id()] = cur_tensor_id
                 assert dist_tensor is not None, \
                     "Tensor must have a distributed tensor after the initialization for program."
-                self.set_tensor_dist_attr_for_graph(node, dist_tensor.dist_attr)
+                serial_tensor_node_id = node.id()
+                new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
+                                                    dist_tensor.dist_attr)
+                self._dist_tensors_for_graph[
+                    serial_tensor_node_id] = new_dist_tensor
             if node.is_op() and node.op() is not None:
-                op_desc = node.op()
-                op_id = op_desc.id()
-                dist_op = self._dist_ops_for_program.get(op_id, None)
+                dist_op = None
+                op_id = node.node.original_desc_id()
+                for cur_op_id, cur_dist_op in self._dist_ops_for_program.items(
+                ):
+                    if op_id == cur_op_id \
+                        or op_id == cur_dist_op.serial_op.desc.original_id():
+                        dist_op = cur_dist_op
+                        self._node_id_to_op_id[node.id()] = cur_op_id
                 assert dist_op is not None, \
                     "Operator must have a distributed operator after the initialization for program."
-                self.set_op_dist_attr_for_graph(node, dist_op.dist_attr)
+                serial_op_node_id = node.id()
+                new_dist_op = DistributedOperator(dist_op.serial_op,
+                                                  dist_op.dist_attr)
+                self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
         self._is_initialized_for_graph = True
 
     def clear_dist_info_for_program(self):
@@ -247,9 +275,8 @@ class DistributedContext:
         all_nodes = self._serial_graph.all_nodes()
         for node in all_nodes:
             if node.is_var() and node.var() is not None:
-                tensor_desc = node.var()
-                tensor_id = tensor_desc.id()
-                updated = updated_tensors.get(tensor_desc.name(), False)
+                tensor_id = self._node_id_to_tensor_id[node.id()]
+                updated = updated_tensors.get(tensor_id, False)
                 # If a var has multiples var nodes in graph, only use the first one for now
                 if not updated:
                     tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph(
@@ -257,10 +284,9 @@ class DistributedContext:
                         node)
                     dist_tensor_for_program = self._dist_tensors_for_program[
                         tensor_id]
                     dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
-                    updated_tensors[tensor_desc.name()] = True
+                    updated_tensors[tensor_id] = True
             if node.is_op() and node.op() is not None:
-                op_desc = node.op()
-                op_id = op_desc.id()
+                op_id = self._node_id_to_op_id[node.id()]
                 op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
                 dist_op_for_program = self._dist_ops_for_program[op_id]
                 dist_op_for_program.dist_attr = op_dist_attr_for_graph
@@ -360,7 +386,7 @@ class DistributedOperatorContext:
         self._rank_id = None
         self._cur_src_op = None
         self._cur_dist_attr = None
-        self.gradopidx2opidx = {}
+        self.grad_op_id_to_op_id = {}
         self.already_init_sync_vars = set()
 
     def __deepcopy__(self, memo):
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py
index e2ebf1cfe6..1a3d57bf14 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_default.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -22,6 +22,7 @@ from ..utils import is_valid_list_index
 from ..utils import compute_compatible_dim_mapping
 from ..utils import compute_compatible_dims_mapping
 from ..utils import compute_compatible_and_update_dim_mapping
+from ..utils import set_dist_op_desc_original_id
 from ..dist_attribute import OperatorDistributedAttribute
 from paddle.fluid import core, unique_name
 from paddle.fluid.framework import in_dygraph_mode
@@ -86,6 +87,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         # replicate op in dist program
         dist_op_desc = main_block.desc.append_op()
         dist_op_desc.copy_from(src_op.desc)
+        set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
         for input_name in src_op.desc.input_names():
             dist_op_desc.set_input(input_name, kwargs[input_name])
         for output_name in src_op.desc.output_names():
@@ -170,6 +172,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         # replicate op in dist program
         dist_op_desc = main_block.desc.append_op()
         dist_op_desc.copy_from(backward_op.desc)
+        # Refer to the related dist op
+        set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
         for input_name in backward_op.desc.input_names():
             dist_op_desc.set_input(input_name, kwargs[input_name])
         for output_name in backward_op.desc.output_names():
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
index 9b0bdabc6d..f4c31c3654 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -25,6 +25,7 @@ from ..utils import is_valid_list_index
 from ..utils import compute_compatible_dim_mapping
 from ..utils import compute_compatible_dims_mapping
 from ..utils import compute_compatible_and_update_dim_mapping
+from ..utils import set_dist_op_desc_original_id
 from ..dist_attribute import OperatorDistributedAttribute
 from paddle.fluid import core, unique_name
 from paddle.fluid.framework import in_dygraph_mode
@@ -35,9 +36,10 @@ from ..process_group import new_process_group
 from ..utils import _get_comm_group, _get_corresponding_rank
 
 
-def copy_op_with_new_input_output(block, src_op, **kwargs):
+def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
     dist_op_desc = block.desc.append_op()
     dist_op_desc.copy_from(src_op.desc)
+    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
     for input_name in src_op.desc.input_names():
         assert input_name in kwargs
         dist_op_desc.set_input(input_name, kwargs[input_name])
@@ -253,7 +255,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
             new_kwargs = copy.deepcopy(kwargs)
             new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
             matmul_op_desc = copy_op_with_new_input_output(
-                main_block, backward_op, **new_kwargs)
+                ctx, main_block, backward_op, **new_kwargs)
         else:
             # col parallel: matmul + allreduce
             assert Y_var_dim_mapping[0] < 0
@@ -281,7 +283,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
 
             new_kwargs['X@GRAD'] = [intermediate_var_0.name]
             matmul_op_desc = copy_op_with_new_input_output(
-                main_block, backward_op, **new_kwargs)
+                ctx, main_block, backward_op, **new_kwargs)
 
             # NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
             if has_x_grad:
@@ -304,8 +306,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
                                            X_grad_dist_attr, ctx)
     else:
         # replicate
-        matmul_op_desc = copy_op_with_new_input_output(main_block, backward_op,
-                                                       **kwargs)
+        matmul_op_desc = copy_op_with_new_input_output(ctx, main_block,
+                                                       backward_op, **kwargs)
 
     main_block._sync_with_cpp()
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
index aba9704ad5..e287bd75b3 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
@@ -22,6 +22,7 @@ from ..utils import is_valid_list_index
 from ..utils import compute_compatible_dim_mapping
 from ..utils import compute_compatible_dims_mapping
 from ..utils import compute_compatible_and_update_dim_mapping
+from ..utils import set_dist_op_desc_original_id
 from paddle.fluid import core, unique_name
 from paddle.fluid.framework import in_dygraph_mode
 from paddle.fluid.framework import Program, Parameter, Variable, program_guard
@@ -181,6 +182,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
         # create op
         new_op_desc = main_block.desc.append_op()
         new_op_desc.copy_from(src_op.desc)
+        set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
         new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
         new_op_desc.set_input('Shape', Shape_var_list)
         new_op_desc.set_input('X', [X_var.name])
@@ -345,6 +347,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
         # create op
         new_op_desc = main_block.desc.append_op()
         new_op_desc.copy_from(src_op.desc)
+        set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
         new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
         new_op_desc.set_input('Shape', Shape_var_list)
         new_op_desc.set_input('X', [X_var.name])
diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py
index 04d5f1db59..0042dd8e82 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -39,6 +39,7 @@ from .process_group import get_world_process_groups
 from .process_group import _g_process_group_map, ProcessGroup
 from .utils import make_data_unshard
 from .utils import set_grad_var_shape
+from .utils import print_program_with_dist_attr
 from .utils import SerialProgramInfo
 from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
 from .cluster import Cluster
diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py
index 096de1c206..76a9faa1c8 100644
--- a/python/paddle/distributed/auto_parallel/partitioner.py
+++ b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -24,6 +24,7 @@ from paddle.distributed.auto_parallel.operators.common import get_distributed_op
 from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
 from .dist_attribute import OperatorDistributedAttribute
 from .process_group import new_process_group
+from .utils import set_dist_op_desc_original_id
 from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
 from .operators.common import BACKWARD_ONLY_DIST_OPS
 
@@ -363,8 +364,9 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
 def _get_dist_op_backward_implement(backward_op, dist_context,
                                     forward_op_id2forward_op):
     dist_op_context = dist_context.dist_op_context
-    if backward_op.desc.id() in dist_op_context.gradopidx2opidx:
-        forward_op_id = dist_op_context.gradopidx2opidx[backward_op.desc.id()]
+    if backward_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
+        forward_op_id = dist_op_context.grad_op_id_to_op_id[backward_op.desc.id(
+        )]
         forward_op = forward_op_id2forward_op[forward_op_id]
         forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
             forward_op)
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index 5198b8f5fd..2316f207ff 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -1402,3 +1402,19 @@ def get_standalone_cost_data(distributed_programs):
         standalone_cost_data.append(cost_data)
 
     return standalone_cost_data
+
+
+def set_dist_op_desc_original_id(dist_op_desc, op_desc, dist_context):
+    op_id = op_desc.id()
+    op_original_id = op_desc.original_id()
+    # First, try to set the original id to the id of the op_desc
+    if op_id in dist_context._dist_ops_for_program:
+        dist_op_desc.set_original_id(op_id)
+        return
+    # Second, try to set the original id to the original_id of the op_desc
+    elif op_original_id in dist_context._dist_ops_for_program:
+        dist_op_desc.set_original_id(op_original_id)
+        return
+    # Third, raise an error if the original id cannot be found
+    else:
+        assert False, "Cannot find the original id in the distributed context"
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 9ea407c760..4805994b7a 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -202,7 +202,7 @@ class ProgramStats(object):
             if op.desc.has_attr(op_device_attr_name):
                 op_device = op.desc.attr(op_device_attr_name)
 
-            # Setting the force_cpu of seed to true will make the output of seed in cpu memory, 
+            # Setting the force_cpu of seed to true will make the output of seed in cpu memory,
             # reduce the synchronous copy from GPU to CPU in dropout, and reduce the communication hang
             added_op = self.block._insert_op(
                 index=op.idx,
@@ -957,7 +957,7 @@ def _append_backward_ops_with_checkpoints_(
             # added_descs should be in grad_op_descs because it is backward op desc
             grad_op_descs.extend(buffer_descs)
 
-            # 3.c. add backward ops for all ops in current segment 
+            # 3.c. add backward ops for all ops in current segment
             for op_desc in reversed(added_descs):
                 grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                     op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
@@ -1109,10 +1109,11 @@ def _append_backward_ops_(block,
         # Getting op's corresponding grad_op
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
+        # Build the mapping between the forward op and backward op (Only for auto parallel)
         if distop_context is not None:
             for op_desc in grad_op_desc:
-                assert op_desc.id() not in distop_context.gradopidx2opidx
-                distop_context.gradopidx2opidx[op_desc.id()] = op.desc.id()
+                assert op_desc.id() not in distop_context.grad_op_id_to_op_id
+                distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id()
 
         # Set device for grad_op according to forward Op
         device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
@@ -1197,6 +1198,12 @@ def _append_backward_ops_(block,
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
+        # Rebuild the mapping because new_op_desc has a different id (Only for auto parallel)
+        if distop_context is not None:
+            if op_desc.id() in distop_context.grad_op_id_to_op_id:
+                distop_context.grad_op_id_to_op_id[new_op_desc.id(
+                )] = distop_context.grad_op_id_to_op_id[op_desc.id()]
+                distop_context.grad_op_id_to_op_id.pop(op_desc.id())
         new_op_desc._set_attr(op_role_attr_name, backward)
         grad_to_var["__current_op_desc__"] = new_op_desc
         if callbacks is not None:
-- 
GitLab
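
Reviewer note, outside the patch: a minimal, runnable sketch of the id bookkeeping introduced above. The names ToyDesc and ToyDistContext are hypothetical and only mirror the behavior of OpDesc::CopyFrom, VarDesc's copy constructor, and DistributedContext.get_dist_op_for_program in this change: every copied desc receives a fresh unique id but inherits the original id of its source, and context lookups fall back from the current id to that original id.

# Editorial sketch only; not part of the Paddle code base.


class ToyDesc(object):
    _next_id = 0

    def __init__(self, original_id=None):
        ToyDesc._next_id += 1
        self.id = ToyDesc._next_id
        # A fresh desc is its own original; a copy records its source.
        self.original_id = self.id if original_id is None else original_id

    def copy(self):
        # Mimics OpDesc::CopyFrom / the VarDesc copy constructor: new id,
        # inherited original id.
        return ToyDesc(original_id=self.original_id)


class ToyDistContext(object):
    def __init__(self):
        self._dist_ops_for_program = {}

    def add(self, desc, dist_attr):
        self._dist_ops_for_program[desc.id] = dist_attr

    def get(self, desc):
        # Same two-step lookup as get_dist_op_for_program: try the desc's own
        # id first, then fall back to the id of the desc it was copied from.
        dist_attr = self._dist_ops_for_program.get(desc.id)
        if dist_attr is None:
            dist_attr = self._dist_ops_for_program.get(desc.original_id)
        return dist_attr


if __name__ == "__main__":
    ctx = ToyDistContext()
    src = ToyDesc()
    ctx.add(src, dist_attr="annotated")
    clone = src.copy()          # e.g. the desc appended into the dist program
    assert clone.id != src.id   # ids stay globally unique
    assert ctx.get(clone) == "annotated"  # reached through the original id

The same fallback is what set_dist_op_desc_original_id and the grad_op_id_to_op_id remapping in backward.py rely on: a key stored in the distributed context remains reachable from any later copy of the desc.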