Unverified commit 5620214e authored by Yulong Ao, committed by GitHub

[Auto parallel] Make sure the id semantics of every var and op are unique (#38132)

* [Auto parallel] Make the id of var and op unique

* [Auto Parallel] Rename back dist_context to distop_context
Parent ccf99b66
...@@ -122,6 +122,10 @@ class Node { ...@@ -122,6 +122,10 @@ class Node {
// Please don't use this API! // Please don't use this API!
int id() const { return id_; } int id() const { return id_; }
// Only use this for auto parallel.
// A node does not have an original desc if the returned value is zero.
uint64_t OriginalDescId() const { return original_desc_id_; }
bool IsOp() const { return type_ == Type::kOperation; } bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; } bool IsVar() const { return type_ == Type::kVariable; }
bool IsCtrlVar() const { bool IsCtrlVar() const {
...@@ -239,6 +243,10 @@ class Node { ...@@ -239,6 +243,10 @@ class Node {
int desc_order_; int desc_order_;
int block_id_{-1}; int block_id_{-1};
// Store the original id of var desc or op desc.
// Only use this for auto parallel.
uint64_t original_desc_id_{0};
private: private:
// ID can only set by a Graph. // ID can only set by a Graph.
void SetId(int id) { id_ = id; } void SetId(int id) { id_ = id; }
...@@ -267,14 +275,16 @@ class Node { ...@@ -267,14 +275,16 @@ class Node {
op_desc_(nullptr), op_desc_(nullptr),
type_(Type::kVariable), type_(Type::kVariable),
desc_order_(NO_DESC_ORDER), desc_order_(NO_DESC_ORDER),
block_id_(block_id) {} block_id_(block_id),
original_desc_id_(var_desc->OriginalId()) {}
explicit Node(OpDesc* op_desc) explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()), : name_(op_desc->Type()),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())), op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation), type_(Type::kOperation),
desc_order_(NO_DESC_ORDER) {} desc_order_(NO_DESC_ORDER),
original_desc_id_(op_desc->OriginalId()) {}
Node() = delete; Node() = delete;
......
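Each IR node now remembers the id of the VarDesc/OpDesc it was built from, because the node holds a fresh copy of the op desc rather than the original. A minimal Python-side sketch of how a pass can map graph nodes back to program ops via this id, assuming an IrGraph-like wrapper whose op nodes expose the `original_desc_id` binding added later in this PR:

```python
# Sketch only: map IR op nodes back to the program ops they were created from.
# Assumes `program` is a paddle.static.Program and `graph` is an IrGraph-like
# wrapper whose op nodes expose node.node.original_desc_id() (added in this PR).
def map_op_nodes_to_program_ops(program, graph):
    ops_by_desc_id = {}
    for block in program.blocks:
        for op in block.ops:
            # id() is unique per desc; original_id() links a copied desc back
            # to the desc it was built from.
            ops_by_desc_id[op.desc.id()] = op
            ops_by_desc_id[op.desc.original_id()] = op

    node_to_op = {}
    for node in graph.all_op_nodes():
        desc_id = node.node.original_desc_id()
        if desc_id in ops_by_desc_id:
            node_to_op[node.id()] = ops_by_desc_id[desc_id]
    return node_to_op
```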
...@@ -352,15 +352,9 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) { ...@@ -352,15 +352,9 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
inputs_ = op_desc.inputs_; inputs_ = op_desc.inputs_;
outputs_ = op_desc.outputs_; outputs_ = op_desc.outputs_;
attrs_ = op_desc.attrs_; attrs_ = op_desc.attrs_;
// The record of original_id_ is only for auto parallel.
original_id_ = op_desc.original_id_;
need_update_ = true; need_update_ = true;
// When creating graph from program, the creation of op node will create a new
// OpDesc instead of
// referring to the original one. To find the original OpDesc of the op node,
// the id have to be
// copied to the new OpDesc. The var node has the same situation, but the
// default copy constructor
// can copy the id automatically.
id_ = op_desc.id_;
} }
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block) OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
......
...@@ -154,17 +154,10 @@ class OpDesc { ...@@ -154,17 +154,10 @@ class OpDesc {
const BlockDesc *Block() const { return this->block_; } const BlockDesc *Block() const { return this->block_; }
// This thread-safe implementation seems to be redundant since the neural // The Id() and OriginalId() are only used for auto parallel.
// networks
// are usually constructed in a single thread
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> id{0};
return ++id;
}
// Note: the identity only used as a key for referring to its
// distributed attribute now.
uint64_t Id() const { return id_; } uint64_t Id() const { return id_; }
uint64_t OriginalId() const { return original_id_; }
void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
private: private:
template <typename MapType> template <typename MapType>
...@@ -177,6 +170,14 @@ class OpDesc { ...@@ -177,6 +170,14 @@ class OpDesc {
return ret_val; return ret_val;
} }
// This thread-safe implementation seems to be redundant since the neural
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> uid{0};
// Must start from one
return ++uid;
}
proto::OpDesc desc_; proto::OpDesc desc_;
BlockDesc *block_{nullptr}; // not_own BlockDesc *block_{nullptr}; // not_own
// input arg name => input variable names // input arg name => input variable names
...@@ -189,7 +190,13 @@ class OpDesc { ...@@ -189,7 +190,13 @@ class OpDesc {
// local changes should be synchronized, need_update_ should be set to true. // local changes should be synchronized, need_update_ should be set to true.
bool need_update_{false}; bool need_update_{false};
// Note: the id_ is unique for all OpDesc (only for auto parallel).
uint64_t id_ = GenerateId(); uint64_t id_ = GenerateId();
// Note: the original_id_ is used for referring to the original OpDesc
// that the current OpDesc is built from (only for auto parallel).
// The default original_id_ is the same as the id_, which means the
// current OpDesc is not built from another one.
uint64_t original_id_ = id_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -69,6 +69,12 @@ class VarDesc { ...@@ -69,6 +69,12 @@ class VarDesc {
explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {} explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
// Explicitly implement the copy constructor for auto parallel
VarDesc(const VarDesc &other)
: desc_(other.desc_),
attrs_(other.attrs_),
original_id_(other.original_id_) {}
proto::VarDesc *Proto() { return &desc_; } proto::VarDesc *Proto() { return &desc_; }
const proto::VarDesc *Proto() const { return &desc_; } const proto::VarDesc *Proto() const { return &desc_; }
...@@ -153,16 +159,10 @@ class VarDesc { ...@@ -153,16 +159,10 @@ class VarDesc {
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
// This thread-safe implementation seems to be redundant since the neural // The Id() and OriginalId() are only used for auto parallel.
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> uid{0};
return ++uid;
}
// Note: the identity only used as a key for referring to its
// distributed attribute now.
uint64_t Id() const { return id_; } uint64_t Id() const { return id_; }
uint64_t OriginalId() const { return original_id_; }
void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
private: private:
const proto::VarType::TensorDesc &tensor_desc() const; const proto::VarType::TensorDesc &tensor_desc() const;
...@@ -170,9 +170,23 @@ class VarDesc { ...@@ -170,9 +170,23 @@ class VarDesc {
proto::VarType::TensorDesc *mutable_tensor_desc(); proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs(); std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
// This thread-safe implementation seems to be redundant since the neural
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> uid{0};
return ++uid;
}
proto::VarDesc desc_; proto::VarDesc desc_;
AttributeMap attrs_; AttributeMap attrs_;
// Note: the id_ is unique for all VarDesc (only for auto parallel).
uint64_t id_ = GenerateId(); uint64_t id_ = GenerateId();
// Note: the original_id_ is used for referring to the original VarDesc
// that the current VarDesc is built from (only for auto parallel).
// The default original_id_ is the same as the id_, which means the
// current VarDesc is not built from another one.
uint64_t original_id_ = id_;
}; };
bool operator==(const VarDesc &left, const VarDesc &right); bool operator==(const VarDesc &left, const VarDesc &right);
......
...@@ -143,6 +143,7 @@ void BindNode(py::module *m) { ...@@ -143,6 +143,7 @@ void BindNode(py::module *m) {
.def("var", &Node::Var, return_value_policy::reference) .def("var", &Node::Var, return_value_policy::reference)
.def("op", &Node::Op, return_value_policy::reference) .def("op", &Node::Op, return_value_policy::reference)
.def("id", &Node::id) .def("id", &Node::id)
.def("original_desc_id", &Node::OriginalDescId)
.def("is_op", &Node::IsOp) .def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar) .def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar) .def("is_ctrl_var", &Node::IsCtrlVar)
......
...@@ -208,6 +208,8 @@ void BindVarDsec(pybind11::module *m) { ...@@ -208,6 +208,8 @@ void BindVarDsec(pybind11::module *m) {
.def("_set_attr", &pd::VarDesc::SetAttr) .def("_set_attr", &pd::VarDesc::SetAttr)
.def("remove_attr", &pd::VarDesc::RemoveAttr) .def("remove_attr", &pd::VarDesc::RemoveAttr)
.def("id", &pd::VarDesc::Id) .def("id", &pd::VarDesc::Id)
.def("original_id", &pd::VarDesc::OriginalId)
.def("set_original_id", &pd::VarDesc::SetOriginalId)
.def("attr", &pd::VarDesc::GetAttr); .def("attr", &pd::VarDesc::GetAttr);
pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", ""); pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
...@@ -305,6 +307,8 @@ void BindOpDesc(pybind11::module *m) { ...@@ -305,6 +307,8 @@ void BindOpDesc(pybind11::module *m) {
.def("block", [](pd::OpDesc &self) { return self.Block(); }, .def("block", [](pd::OpDesc &self) { return self.Block(); },
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("id", &pd::OpDesc::Id) .def("id", &pd::OpDesc::Id)
.def("original_id", &pd::OpDesc::OriginalId)
.def("set_original_id", &pd::OpDesc::SetOriginalId)
.def("inputs", &pd::OpDesc::Inputs) .def("inputs", &pd::OpDesc::Inputs)
.def("outputs", &pd::OpDesc::Outputs); .def("outputs", &pd::OpDesc::Outputs);
} }
......
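With the bindings above, the id lineage of a desc can be inspected and manipulated from Python. A small illustrative sketch of the intended semantics, assuming a Paddle build that contains these bindings (the program below is only a placeholder):

```python
# Illustrative sketch of the new desc-level id semantics.
import paddle

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    y = paddle.scale(x, scale=2.0)

op = main_program.global_block().ops[-1]
# In a freshly built program, original_id() defaults to id().
assert op.desc.original_id() == op.desc.id()

# Replicating the desc gives it a new unique id; copy_from no longer copies
# id_, but it does carry original_id_ over, so the lineage is preserved.
new_desc = main_program.global_block().desc.append_op()
new_desc.copy_from(op.desc)
assert new_desc.id() != op.desc.id()
assert new_desc.original_id() == op.desc.original_id()

# A pass may also set the link explicitly (see set_dist_op_desc_original_id).
new_desc.set_original_id(op.desc.id())
```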
...@@ -698,13 +698,13 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): ...@@ -698,13 +698,13 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
continue continue
# complete the annotation of grad op (xxx_grad op or sum op) # complete the annotation of grad op (xxx_grad op or sum op)
# xxx_grad op will have a corresponding forward op in gradopidx2opidx # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
grad_op = ops[idx] grad_op = ops[idx]
if grad_op.desc.id() in dist_op_context.gradopidx2opidx: if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
# TODO support the case where one forward op corresponds to multiple xxx_grad ops # TODO support the case where one forward op corresponds to multiple xxx_grad ops
forward_op = _get_op_by_id( forward_op = _get_op_by_id(
ops[:first_backward_op_idx], ops[:first_backward_op_idx],
dist_op_context.gradopidx2opidx[grad_op.desc.id()]) dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
assert forward_op is not None assert forward_op is not None
# op dist attr # op dist attr
...@@ -769,7 +769,7 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): ...@@ -769,7 +769,7 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
dist_context.set_op_dist_attr_for_program(grad_op, dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_dist_attr) grad_op_dist_attr)
# only sum op for merge mutiple version grad has no a corresponding mapping in gradopidx2opidx # only the sum op that merges multiple versions of a grad has no corresponding mapping in grad_op_id_to_op_id
else: else:
assert grad_op.type == "sum", "got unexpect op [{}]".format( assert grad_op.type == "sum", "got unexpect op [{}]".format(
str(grad_op.type)) str(grad_op.type))
......
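For reference, `_get_op_by_id` (used above but not shown in this diff) only needs to resolve a forward op from its desc id. A hypothetical sketch of such a helper; the real implementation lives in completion.py and may differ in details:

```python
# Hypothetical helper sketch, not part of this PR.
def _get_op_by_id(ops, desc_id):
    for op in ops:
        if op.desc.id() == desc_id:
            return op
    return None
```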
...@@ -46,14 +46,19 @@ class DistributedContext: ...@@ -46,14 +46,19 @@ class DistributedContext:
""" """
def __init__(self, program=None): def __init__(self, program=None):
# Program related data members
self._serial_program = program self._serial_program = program
self._serial_graph = None
self._is_initialized_for_program = False self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._dist_tensors_for_program = {} self._dist_tensors_for_program = {}
self._dist_ops_for_program = {} self._dist_ops_for_program = {}
# Graph related data members
self._is_initialized_for_graph = False
self._serial_graph = None
self._dist_tensors_for_graph = {} self._dist_tensors_for_graph = {}
self._dist_ops_for_graph = {} self._dist_ops_for_graph = {}
self._node_id_to_tensor_id = {}
self._node_id_to_op_id = {}
# Other data members
self._dist_op_context = DistributedOperatorContext() self._dist_op_context = DistributedOperatorContext()
self._process_meshes = [] self._process_meshes = []
...@@ -97,19 +102,38 @@ class DistributedContext: ...@@ -97,19 +102,38 @@ class DistributedContext:
def get_dist_tensor_for_program(self, serial_tensor): def get_dist_tensor_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id() serial_tensor_id = serial_tensor.desc.id()
return self._dist_tensors_for_program.get(serial_tensor_id, None) dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
if dist_tensor:
return dist_tensor
else:
serial_tensor_id = serial_tensor.desc.original_id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id,
None)
if dist_tensor:
return dist_tensor
else:
return None
def get_dist_tensor_for_graph(self, serial_tensor_node): def get_dist_tensor_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id() serial_tensor_node_id = serial_tensor_node.id()
return self._dist_tensors_for_graph.get(serial_tensor_node_id, None) return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)
def get_dist_op_for_program(self, serial_tensor): def get_dist_op_for_program(self, serial_op):
serial_tensor_id = serial_tensor.desc.id() serial_op_id = serial_op.desc.id()
return self._dist_ops_for_program.get(serial_tensor_id, None) dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op
else:
serial_op_id = serial_op.desc.original_id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op
else:
return None
def get_dist_op_for_graph(self, serial_tensor_node): def get_dist_op_for_graph(self, serial_op_node):
serial_tensor_node_id = serial_tensor_node.id() serial_op_node_id = serial_op_node.id()
return self._dist_ops_for_graph.get(serial_tensor_node_id, None) return self._dist_ops_for_graph.get(serial_op_node_id, None)
def get_tensor_dist_attr_for_program(self, serial_tensor): def get_tensor_dist_attr_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id() serial_tensor_id = serial_tensor.desc.id()
...@@ -117,7 +141,13 @@ class DistributedContext: ...@@ -117,7 +141,13 @@ class DistributedContext:
if dist_tensor: if dist_tensor:
return dist_tensor.dist_attr return dist_tensor.dist_attr
else: else:
return None serial_tensor_id = serial_tensor.desc.original_id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id,
None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr): def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
dist_tensor = DistributedTensor(serial_tensor, dist_attr) dist_tensor = DistributedTensor(serial_tensor, dist_attr)
...@@ -132,25 +162,18 @@ class DistributedContext: ...@@ -132,25 +162,18 @@ class DistributedContext:
else: else:
return None return None
def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
assert serial_tensor_node.is_var() and \
serial_tensor_node.var() is not None
serial_tensor_id = serial_tensor_node.var().id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
assert dist_tensor is not None, \
"The distributed tensor of the program has not been added to this context."
serial_tensor_node_id = serial_tensor_node.id()
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_attr)
self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
def get_op_dist_attr_for_program(self, serial_op): def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id() serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None) dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op: if dist_op:
return dist_op.dist_attr return dist_op.dist_attr
else: else:
return None serial_op_id = serial_op.desc.original_id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_program(self, serial_op, dist_attr): def set_op_dist_attr_for_program(self, serial_op, dist_attr):
dist_op = DistributedOperator(serial_op, dist_attr) dist_op = DistributedOperator(serial_op, dist_attr)
...@@ -164,17 +187,6 @@ class DistributedContext: ...@@ -164,17 +187,6 @@ class DistributedContext:
else: else:
return None return None
def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
assert serial_op_node.is_op() and \
serial_op_node.op() is not None
serial_op_id = serial_op_node.op().id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
assert dist_op is not None, \
"The distributed operator of the program has not been added to this context."
serial_op_node_id = serial_op_node.id()
new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
def init_dist_attr_for_program(self): def init_dist_attr_for_program(self):
assert self._serial_program, \ assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes." "Please set the program of this context before initializing its distribute attributes."
...@@ -216,20 +228,36 @@ class DistributedContext: ...@@ -216,20 +228,36 @@ class DistributedContext:
all_nodes = self._serial_graph.all_nodes() all_nodes = self._serial_graph.all_nodes()
for node in all_nodes: for node in all_nodes:
if node.is_var() and node.var() is not None: if node.is_var() and node.var() is not None:
tensor_desc = node.var() dist_tensor = None
tensor_id = tensor_desc.id() tensor_id = node.node.original_desc_id()
dist_tensor = self._dist_tensors_for_program.get(tensor_id, for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items(
None) ):
if tensor_id == cur_tensor_id \
or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
dist_tensor = cur_dist_tensor
self._node_id_to_tensor_id[node.id()] = cur_tensor_id
assert dist_tensor is not None, \ assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program." "Tensor must have a distributed tensor after the initialization for program."
self.set_tensor_dist_attr_for_graph(node, dist_tensor.dist_attr) serial_tensor_node_id = node.id()
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_tensor.dist_attr)
self._dist_tensors_for_graph[
serial_tensor_node_id] = new_dist_tensor
if node.is_op() and node.op() is not None: if node.is_op() and node.op() is not None:
op_desc = node.op() dist_op = None
op_id = op_desc.id() op_id = node.node.original_desc_id()
dist_op = self._dist_ops_for_program.get(op_id, None) for cur_op_id, cur_dist_op in self._dist_ops_for_program.items(
):
if op_id == cur_op_id \
or op_id == cur_dist_op.serial_op.desc.original_id():
dist_op = cur_dist_op
self._node_id_to_op_id[node.id()] = cur_op_id
assert dist_op is not None, \ assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program." "Operator must have a distributed operator after the initialization for program."
self.set_op_dist_attr_for_graph(node, dist_op.dist_attr) serial_op_node_id = node.id()
new_dist_op = DistributedOperator(dist_op.serial_op,
dist_op.dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
self._is_initialized_for_graph = True self._is_initialized_for_graph = True
def clear_dist_info_for_program(self): def clear_dist_info_for_program(self):
...@@ -247,9 +275,8 @@ class DistributedContext: ...@@ -247,9 +275,8 @@ class DistributedContext:
all_nodes = self._serial_graph.all_nodes() all_nodes = self._serial_graph.all_nodes()
for node in all_nodes: for node in all_nodes:
if node.is_var() and node.var() is not None: if node.is_var() and node.var() is not None:
tensor_desc = node.var() tensor_id = self._node_id_to_tensor_id[node.id()]
tensor_id = tensor_desc.id() updated = updated_tensors.get(tensor_id, False)
updated = updated_tensors.get(tensor_desc.name(), False)
# If a var has multiples var nodes in graph, only use the first one for now # If a var has multiples var nodes in graph, only use the first one for now
if not updated: if not updated:
tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph( tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph(
...@@ -257,10 +284,9 @@ class DistributedContext: ...@@ -257,10 +284,9 @@ class DistributedContext:
dist_tensor_for_program = self._dist_tensors_for_program[ dist_tensor_for_program = self._dist_tensors_for_program[
tensor_id] tensor_id]
dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
updated_tensors[tensor_desc.name()] = True updated_tensors[tensor_id] = True
if node.is_op() and node.op() is not None: if node.is_op() and node.op() is not None:
op_desc = node.op() op_id = self._node_id_to_op_id[node.id()]
op_id = op_desc.id()
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node) op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id] dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph dist_op_for_program.dist_attr = op_dist_attr_for_graph
...@@ -360,7 +386,7 @@ class DistributedOperatorContext: ...@@ -360,7 +386,7 @@ class DistributedOperatorContext:
self._rank_id = None self._rank_id = None
self._cur_src_op = None self._cur_src_op = None
self._cur_dist_attr = None self._cur_dist_attr = None
self.gradopidx2opidx = {} self.grad_op_id_to_op_id = {}
self.already_init_sync_vars = set() self.already_init_sync_vars = set()
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
......
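All of the program-side getters above now share the same two-step lookup: try the desc's current id first, then fall back to its original id, so a desc replicated by the partitioner still resolves to the distributed attribute registered for its source desc. Condensed into one helper for clarity (a sketch, not part of the PR):

```python
# Condensed sketch of the id -> original_id fallback used by
# get_dist_tensor_for_program / get_dist_op_for_program and the
# *_dist_attr_for_program getters above.
def _lookup_by_id_then_original_id(table, desc):
    """table maps desc ids to dist objects; desc is an OpDesc/VarDesc proxy."""
    hit = table.get(desc.id(), None)
    if hit is not None:
        return hit
    # Fall back to the id of the desc this one was copied from.
    return table.get(desc.original_id(), None)
```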
...@@ -22,6 +22,7 @@ from ..utils import is_valid_list_index ...@@ -22,6 +22,7 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
...@@ -86,6 +87,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -86,6 +87,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.desc.append_op() dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name]) dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
...@@ -170,6 +172,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -170,6 +172,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.desc.append_op() dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(backward_op.desc) dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
for input_name in backward_op.desc.input_names(): for input_name in backward_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name]) dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in backward_op.desc.output_names(): for output_name in backward_op.desc.output_names():
......
...@@ -25,6 +25,7 @@ from ..utils import is_valid_list_index ...@@ -25,6 +25,7 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
...@@ -35,9 +36,10 @@ from ..process_group import new_process_group ...@@ -35,9 +36,10 @@ from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank from ..utils import _get_comm_group, _get_corresponding_rank
def copy_op_with_new_input_output(block, src_op, **kwargs): def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
dist_op_desc = block.desc.append_op() dist_op_desc = block.desc.append_op()
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs assert input_name in kwargs
dist_op_desc.set_input(input_name, kwargs[input_name]) dist_op_desc.set_input(input_name, kwargs[input_name])
...@@ -253,7 +255,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -253,7 +255,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
new_kwargs = copy.deepcopy(kwargs) new_kwargs = copy.deepcopy(kwargs)
new_kwargs['Out@GRAD'] = [intermediate_var_0.name] new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output( matmul_op_desc = copy_op_with_new_input_output(
main_block, backward_op, **new_kwargs) ctx, main_block, backward_op, **new_kwargs)
else: else:
# col parallel: matmul + allreduce # col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0 assert Y_var_dim_mapping[0] < 0
...@@ -281,7 +283,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -281,7 +283,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
new_kwargs['X@GRAD'] = [intermediate_var_0.name] new_kwargs['X@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output( matmul_op_desc = copy_op_with_new_input_output(
main_block, backward_op, **new_kwargs) ctx, main_block, backward_op, **new_kwargs)
# NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad # NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
if has_x_grad: if has_x_grad:
...@@ -304,8 +306,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -304,8 +306,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
X_grad_dist_attr, ctx) X_grad_dist_attr, ctx)
else: else:
# replicate # replicate
matmul_op_desc = copy_op_with_new_input_output(main_block, backward_op, matmul_op_desc = copy_op_with_new_input_output(ctx, main_block,
**kwargs) backward_op, **kwargs)
main_block._sync_with_cpp() main_block._sync_with_cpp()
......
...@@ -22,6 +22,7 @@ from ..utils import is_valid_list_index ...@@ -22,6 +22,7 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.framework import Program, Parameter, Variable, program_guard
...@@ -181,6 +182,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -181,6 +182,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
# create op # create op
new_op_desc = main_block.desc.append_op() new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(src_op.desc) new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
new_op_desc.set_input('Shape', Shape_var_list) new_op_desc.set_input('Shape', Shape_var_list)
new_op_desc.set_input('X', [X_var.name]) new_op_desc.set_input('X', [X_var.name])
...@@ -345,6 +347,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -345,6 +347,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
# create op # create op
new_op_desc = main_block.desc.append_op() new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(src_op.desc) new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
new_op_desc.set_input('Shape', Shape_var_list) new_op_desc.set_input('Shape', Shape_var_list)
new_op_desc.set_input('X', [X_var.name]) new_op_desc.set_input('X', [X_var.name])
......
...@@ -39,6 +39,7 @@ from .process_group import get_world_process_groups ...@@ -39,6 +39,7 @@ from .process_group import get_world_process_groups
from .process_group import _g_process_group_map, ProcessGroup from .process_group import _g_process_group_map, ProcessGroup
from .utils import make_data_unshard from .utils import make_data_unshard
from .utils import set_grad_var_shape from .utils import set_grad_var_shape
from .utils import print_program_with_dist_attr
from .utils import SerialProgramInfo from .utils import SerialProgramInfo
from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
from .cluster import Cluster from .cluster import Cluster
......
...@@ -24,6 +24,7 @@ from paddle.distributed.auto_parallel.operators.common import get_distributed_op ...@@ -24,6 +24,7 @@ from paddle.distributed.auto_parallel.operators.common import get_distributed_op
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from .dist_attribute import OperatorDistributedAttribute from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
from .operators.common import BACKWARD_ONLY_DIST_OPS from .operators.common import BACKWARD_ONLY_DIST_OPS
...@@ -363,8 +364,9 @@ def _partition_var(dist_context, src_block, dst_block, src_varname, ...@@ -363,8 +364,9 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
def _get_dist_op_backward_implement(backward_op, dist_context, def _get_dist_op_backward_implement(backward_op, dist_context,
forward_op_id2forward_op): forward_op_id2forward_op):
dist_op_context = dist_context.dist_op_context dist_op_context = dist_context.dist_op_context
if backward_op.desc.id() in dist_op_context.gradopidx2opidx: if backward_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
forward_op_id = dist_op_context.gradopidx2opidx[backward_op.desc.id()] forward_op_id = dist_op_context.grad_op_id_to_op_id[backward_op.desc.id(
)]
forward_op = forward_op_id2forward_op[forward_op_id] forward_op = forward_op_id2forward_op[forward_op_id]
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op) forward_op)
......
...@@ -1402,3 +1402,19 @@ def get_standalone_cost_data(distributed_programs): ...@@ -1402,3 +1402,19 @@ def get_standalone_cost_data(distributed_programs):
standalone_cost_data.append(cost_data) standalone_cost_data.append(cost_data)
return standalone_cost_data return standalone_cost_data
def set_dist_op_desc_original_id(dist_op_desc, op_desc, dist_context):
op_id = op_desc.id()
op_original_id = op_desc.original_id()
# First, try to set the original id to the id of the op_desc
if op_id in dist_context._dist_ops_for_program:
dist_op_desc.set_original_id(op_id)
return
# Second, try to set the original id to the original_id of the op_desc
elif op_original_id in dist_context._dist_ops_for_program:
dist_op_desc.set_original_id(op_original_id)
return
# Third, raise an error if we cannot find the original id
else:
assert False, "Cannot find the original id in the distributed context"
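This helper is called right after `copy_from` in the distributed op implementations (see `dist_default.py` and `copy_op_with_new_input_output` above), so every replicated op desc ends up with an `original_id` that is a valid key in `ctx._dist_ops_for_program`. A sketch of a typical call site, with names mirroring those implementations:

```python
# Sketch of a typical call site (mirrors copy_op_with_new_input_output in
# dist_matmul.py); block, src_op and ctx come from the dist op implementation.
def replicate_op_desc(ctx, block, src_op, **kwargs):
    dist_op_desc = block.desc.append_op()
    dist_op_desc.copy_from(src_op.desc)
    # Link the copy back to a desc id that the DistributedContext knows about.
    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
    for input_name in src_op.desc.input_names():
        dist_op_desc.set_input(input_name, kwargs[input_name])
    for output_name in src_op.desc.output_names():
        dist_op_desc.set_output(output_name, kwargs[output_name])
    return dist_op_desc
```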
...@@ -202,7 +202,7 @@ class ProgramStats(object): ...@@ -202,7 +202,7 @@ class ProgramStats(object):
if op.desc.has_attr(op_device_attr_name): if op.desc.has_attr(op_device_attr_name):
op_device = op.desc.attr(op_device_attr_name) op_device = op.desc.attr(op_device_attr_name)
# Setting the force_cpu of seed to true will make the output of seed in cpu memory, # Setting the force_cpu of seed to true will make the output of seed in cpu memory,
# reduce the synchronous copy from GPU to CPU in dropout, and reduce the communication hang # reduce the synchronous copy from GPU to CPU in dropout, and reduce the communication hang
added_op = self.block._insert_op( added_op = self.block._insert_op(
index=op.idx, index=op.idx,
...@@ -957,7 +957,7 @@ def _append_backward_ops_with_checkpoints_( ...@@ -957,7 +957,7 @@ def _append_backward_ops_with_checkpoints_(
# added_descs should be in grad_op_descs because it is backward op desc # added_descs should be in grad_op_descs because it is backward op desc
grad_op_descs.extend(buffer_descs) grad_op_descs.extend(buffer_descs)
# 3.c. add backward ops for all ops in current segment # 3.c. add backward ops for all ops in current segment
for op_desc in reversed(added_descs): for op_desc in reversed(added_descs):
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op_desc, cpt.to_text(no_grad_dict[block.idx]), []) op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
...@@ -1109,10 +1109,11 @@ def _append_backward_ops_(block, ...@@ -1109,10 +1109,11 @@ def _append_backward_ops_(block,
# Getting op's corresponding grad_op # Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list) op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
# Build the mapping between the forward op and the backward op (only for auto parallel)
if distop_context is not None: if distop_context is not None:
for op_desc in grad_op_desc: for op_desc in grad_op_desc:
assert op_desc.id() not in distop_context.gradopidx2opidx assert op_desc.id() not in distop_context.grad_op_id_to_op_id
distop_context.gradopidx2opidx[op_desc.id()] = op.desc.id() distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id()
# Set device for grad_op according to forward Op # Set device for grad_op according to forward Op
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
...@@ -1197,6 +1198,12 @@ def _append_backward_ops_(block, ...@@ -1197,6 +1198,12 @@ def _append_backward_ops_(block,
for op_desc in grad_op_descs: for op_desc in grad_op_descs:
new_op_desc = target_block.desc.append_op() new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc) new_op_desc.copy_from(op_desc)
# Rebuild the mapping because new_op_desc has a different id (only for auto parallel)
if distop_context is not None:
if op_desc.id() in distop_context.grad_op_id_to_op_id:
distop_context.grad_op_id_to_op_id[new_op_desc.id(
)] = distop_context.grad_op_id_to_op_id[op_desc.id()]
distop_context.grad_op_id_to_op_id.pop(op_desc.id())
new_op_desc._set_attr(op_role_attr_name, backward) new_op_desc._set_attr(op_role_attr_name, backward)
grad_to_var["__current_op_desc__"] = new_op_desc grad_to_var["__current_op_desc__"] = new_op_desc
if callbacks is not None: if callbacks is not None:
......
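Because `OpDesc::CopyFrom` no longer copies `id_`, the grad-op mapping built while generating `grad_op_descs` is keyed by the ids of temporary descs; when those descs are copied into the target block, the entries have to be re-keyed to the ids of the appended descs, which is what the block above does. The re-keying step in isolation (sketch):

```python
# Isolated sketch of the re-keying performed above; grad_op_id_to_op_id maps
# a grad op desc id to the desc id of its forward op.
def rekey_grad_op_mapping(grad_op_id_to_op_id, old_desc_id, new_desc_id):
    if old_desc_id in grad_op_id_to_op_id:
        grad_op_id_to_op_id[new_desc_id] = grad_op_id_to_op_id.pop(old_desc_id)
```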