Unverified commit 5620214e authored by Yulong Ao and committed by GitHub

[Auto parallel] Make sure the id semantics of every var and op unique (#38132)

* [Auto parallel] Make the id of var and op unique

* [Auto Parallel] Rename back dist_context to distop_context
Parent ccf99b66
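For context, the core of this change is that every `VarDesc`/`OpDesc` gets a process-unique id at construction, while `original_id` records which desc it was copied from (defaulting to the desc's own id). A minimal Python sketch of these semantics, using hypothetical names rather than Paddle's API:

```python
import itertools

_uid = itertools.count(1)  # ids must start from one, mirroring GenerateId()

class Desc:
    """Hypothetical stand-in for OpDesc/VarDesc id bookkeeping."""

    def __init__(self):
        self.id = next(_uid)        # unique for every desc ever constructed
        self.original_id = self.id  # defaults to the desc's own id

    def copy_from(self, other):
        # A copy keeps its own unique id but records its lineage, so lookups
        # keyed by the source desc's id can still find it.
        self.original_id = other.original_id

fwd = Desc()
dist_copy = Desc()
dist_copy.copy_from(fwd)
assert dist_copy.id != fwd.id and dist_copy.original_id == fwd.id
```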
......@@ -122,6 +122,10 @@ class Node {
// Please don't use this API!
int id() const { return id_; }
// Only use this for auto parallel.
// A node does not have an original desc if the returned value is zero.
uint64_t OriginalDescId() const { return original_desc_id_; }
bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; }
bool IsCtrlVar() const {
......@@ -239,6 +243,10 @@ class Node {
int desc_order_;
int block_id_{-1};
// Store the original id of var desc or op desc.
// Only use this for auto parallel.
uint64_t original_desc_id_{0};
private:
// ID can only set by a Graph.
void SetId(int id) { id_ = id; }
......@@ -267,14 +275,16 @@ class Node {
op_desc_(nullptr),
type_(Type::kVariable),
desc_order_(NO_DESC_ORDER),
block_id_(block_id) {}
block_id_(block_id),
original_desc_id_(var_desc->OriginalId()) {}
explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()),
var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation),
desc_order_(NO_DESC_ORDER) {}
desc_order_(NO_DESC_ORDER),
original_desc_id_(op_desc->OriginalId()) {}
Node() = delete;
......
......@@ -352,15 +352,9 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
inputs_ = op_desc.inputs_;
outputs_ = op_desc.outputs_;
attrs_ = op_desc.attrs_;
// The record of original_id_ is only for auto parallel.
original_id_ = op_desc.original_id_;
need_update_ = true;
// When creating a graph from a program, the creation of an op node will
// create a new OpDesc instead of referring to the original one. To find
// the original OpDesc of the op node, the id has to be copied to the new
// OpDesc. The var node has the same situation, but the default copy
// constructor can copy the id automatically.
id_ = op_desc.id_;
}
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
......
......@@ -154,17 +154,10 @@ class OpDesc {
const BlockDesc *Block() const { return this->block_; }
// This thread-safe implementation seems to be redundant since neural
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> id{0};
return ++id;
}
// Note: the identity is only used as a key for referring to its
// distributed attribute now.
// The Id() and OriginalId() are only used for auto parallel.
uint64_t Id() const { return id_; }
uint64_t OriginalId() const { return original_id_; }
void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
private:
template <typename MapType>
......@@ -177,6 +170,14 @@ class OpDesc {
return ret_val;
}
// This thread-safe implementation seems to be redundant since neural
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> uid{0};
// Must start from one
return ++uid;
}
proto::OpDesc desc_;
BlockDesc *block_{nullptr}; // not_own
// input arg name => input variable names
......@@ -189,7 +190,13 @@ class OpDesc {
// local changes should be synchronized, need_update_ should be set to true.
bool need_update_{false};
// Note: the id_ is unique (only for auto parallel).
uint64_t id_ = GenerateId();
// Note: the original_id_ is used for referring to the original OpDesc
// that the current OpDesc is built from (only for auto parallel).
// The default original_id_ is the same as the id_, which means the
// current OpDesc is not built from another one.
uint64_t original_id_ = id_;
};
} // namespace framework
} // namespace paddle
......@@ -69,6 +69,12 @@ class VarDesc {
explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
// Explicitly implement the copy constructor for auto parallel
VarDesc(const VarDesc &other)
: desc_(other.desc_),
attrs_(other.attrs_),
original_id_(other.original_id_) {}
proto::VarDesc *Proto() { return &desc_; }
const proto::VarDesc *Proto() const { return &desc_; }
......@@ -153,16 +159,10 @@ class VarDesc {
Attribute GetAttr(const std::string &name) const;
// This thread-safe implementation seems to be redundant since neural
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> uid{0};
return ++uid;
}
// Note: the identity is only used as a key for referring to its
// distributed attribute now.
// The Id() and OriginalId() are only used for auto parallel.
uint64_t Id() const { return id_; }
uint64_t OriginalId() const { return original_id_; }
void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
private:
const proto::VarType::TensorDesc &tensor_desc() const;
......@@ -170,9 +170,23 @@ class VarDesc {
proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
// This thread-safe implementation seems to be redundant since neural
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> uid{0};
return ++uid;
}
proto::VarDesc desc_;
AttributeMap attrs_;
// Note: the id_ is unique for all VarDesc (only for auto parallel).
uint64_t id_ = GenerateId();
// Note: the original_id_ is used for referring to the original VarDesc
// that the current VarDesc is built from (only for auto parallel).
// The default original_id_ is the same as the id_, which means the
// current VarDesc is not built from another one.
uint64_t original_id_ = id_;
};
bool operator==(const VarDesc &left, const VarDesc &right);
......
......@@ -143,6 +143,7 @@ void BindNode(py::module *m) {
.def("var", &Node::Var, return_value_policy::reference)
.def("op", &Node::Op, return_value_policy::reference)
.def("id", &Node::id)
.def("original_desc_id", &Node::OriginalDescId)
.def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar)
......
......@@ -208,6 +208,8 @@ void BindVarDsec(pybind11::module *m) {
.def("_set_attr", &pd::VarDesc::SetAttr)
.def("remove_attr", &pd::VarDesc::RemoveAttr)
.def("id", &pd::VarDesc::Id)
.def("original_id", &pd::VarDesc::OriginalId)
.def("set_original_id", &pd::VarDesc::SetOriginalId)
.def("attr", &pd::VarDesc::GetAttr);
pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
......@@ -305,6 +307,8 @@ void BindOpDesc(pybind11::module *m) {
.def("block", [](pd::OpDesc &self) { return self.Block(); },
pybind11::return_value_policy::reference)
.def("id", &pd::OpDesc::Id)
.def("original_id", &pd::OpDesc::OriginalId)
.def("set_original_id", &pd::OpDesc::SetOriginalId)
.def("inputs", &pd::OpDesc::Inputs)
.def("outputs", &pd::OpDesc::Outputs);
}
......
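With the bindings above, the new ids are reachable from Python. A hedged usage sketch (assuming a Paddle build that includes this commit):

```python
import paddle

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    y = paddle.static.nn.fc(x, size=16)

op_desc = prog.global_block().ops[0].desc
# A freshly built desc has original_id() == id(); the two diverge only
# after the desc is copied (e.g. when a dist op is appended from it).
assert op_desc.original_id() == op_desc.id()
```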
......@@ -698,13 +698,13 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
continue
# complete the annotation of grad op (xxx_grad op or sum op)
# xxx_grad op will have a corresponding forward op in gradopidx2opidx
# xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
grad_op = ops[idx]
if grad_op.desc.id() in dist_op_context.gradopidx2opidx:
if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
# TODO: support the case where one forward op corresponds to multiple xxx_grad ops
forward_op = _get_op_by_id(
ops[:first_backward_op_idx],
dist_op_context.gradopidx2opidx[grad_op.desc.id()])
dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
assert forward_op is not None
# op dist attr
......@@ -769,7 +769,7 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_dist_attr)
# only the sum op for merging multiple versions of grad has no corresponding mapping in gradopidx2opidx
# only the sum op for merging multiple versions of grad has no corresponding mapping in grad_op_id_to_op_id
else:
assert grad_op.type == "sum", "got unexpect op [{}]".format(
str(grad_op.type))
......
......@@ -46,14 +46,19 @@ class DistributedContext:
"""
def __init__(self, program=None):
# Program related data members
self._serial_program = program
self._serial_graph = None
self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._dist_tensors_for_program = {}
self._dist_ops_for_program = {}
# Graph related data members
self._is_initialized_for_graph = False
self._serial_graph = None
self._dist_tensors_for_graph = {}
self._dist_ops_for_graph = {}
self._node_id_to_tensor_id = {}
self._node_id_to_op_id = {}
# Other data members
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
......@@ -97,19 +102,38 @@ class DistributedContext:
def get_dist_tensor_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
return self._dist_tensors_for_program.get(serial_tensor_id, None)
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
if dist_tensor:
return dist_tensor
else:
serial_tensor_id = serial_tensor.desc.original_id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id,
None)
if dist_tensor:
return dist_tensor
else:
return None
def get_dist_tensor_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)
def get_dist_op_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
return self._dist_ops_for_program.get(serial_tensor_id, None)
def get_dist_op_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op
else:
serial_op_id = serial_op.desc.original_id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op
else:
return None
def get_dist_op_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
return self._dist_ops_for_graph.get(serial_tensor_node_id, None)
def get_dist_op_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id()
return self._dist_ops_for_graph.get(serial_op_node_id, None)
def get_tensor_dist_attr_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
......@@ -117,7 +141,13 @@ class DistributedContext:
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
serial_tensor_id = serial_tensor.desc.original_id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id,
None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
dist_tensor = DistributedTensor(serial_tensor, dist_attr)
......@@ -132,25 +162,18 @@ class DistributedContext:
else:
return None
def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
assert serial_tensor_node.is_var() and \
serial_tensor_node.var() is not None
serial_tensor_id = serial_tensor_node.var().id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
assert dist_tensor is not None, \
"The distributed tensor of the program has not been added to this context."
serial_tensor_node_id = serial_tensor_node.id()
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_attr)
self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
serial_op_id = serial_op.desc.original_id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_program(self, serial_op, dist_attr):
dist_op = DistributedOperator(serial_op, dist_attr)
......@@ -164,17 +187,6 @@ class DistributedContext:
else:
return None
def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
assert serial_op_node.is_op() and \
serial_op_node.op() is not None
serial_op_id = serial_op_node.op().id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
assert dist_op is not None, \
"The distributed operator of the program has not been added to this context."
serial_op_node_id = serial_op_node.id()
new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
def init_dist_attr_for_program(self):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
......@@ -216,20 +228,36 @@ class DistributedContext:
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
dist_tensor = self._dist_tensors_for_program.get(tensor_id,
None)
dist_tensor = None
tensor_id = node.node.original_desc_id()
for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items(
):
if tensor_id == cur_tensor_id \
or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
dist_tensor = cur_dist_tensor
self._node_id_to_tensor_id[node.id()] = cur_tensor_id
assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program."
self.set_tensor_dist_attr_for_graph(node, dist_tensor.dist_attr)
serial_tensor_node_id = node.id()
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_tensor.dist_attr)
self._dist_tensors_for_graph[
serial_tensor_node_id] = new_dist_tensor
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
dist_op = self._dist_ops_for_program.get(op_id, None)
dist_op = None
op_id = node.node.original_desc_id()
for cur_op_id, cur_dist_op in self._dist_ops_for_program.items(
):
if op_id == cur_op_id \
or op_id == cur_dist_op.serial_op.desc.original_id():
dist_op = cur_dist_op
self._node_id_to_op_id[node.id()] = cur_op_id
assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program."
self.set_op_dist_attr_for_graph(node, dist_op.dist_attr)
serial_op_node_id = node.id()
new_dist_op = DistributedOperator(dist_op.serial_op,
dist_op.dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
self._is_initialized_for_graph = True
def clear_dist_info_for_program(self):
......@@ -247,9 +275,8 @@ class DistributedContext:
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
updated = updated_tensors.get(tensor_desc.name(), False)
tensor_id = self._node_id_to_tensor_id[node.id()]
updated = updated_tensors.get(tensor_id, False)
# If a var has multiples var nodes in graph, only use the first one for now
if not updated:
tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph(
......@@ -257,10 +284,9 @@ class DistributedContext:
dist_tensor_for_program = self._dist_tensors_for_program[
tensor_id]
dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
updated_tensors[tensor_desc.name()] = True
updated_tensors[tensor_id] = True
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_id = self._node_id_to_op_id[node.id()]
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph
......@@ -360,7 +386,7 @@ class DistributedOperatorContext:
self._rank_id = None
self._cur_src_op = None
self._cur_dist_attr = None
self.gradopidx2opidx = {}
self.grad_op_id_to_op_id = {}
self.already_init_sync_vars = set()
def __deepcopy__(self, memo):
......
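The matching rule in `init_dist_attr_for_graph` above deserves a note: a graph node records the desc it was created from via `original_desc_id()`, and that desc may itself be a copy, so the lookup accepts either the registered id or the registered desc's own `original_id`. A small sketch with hypothetical names:

```python
def match_registered_id(node_original_desc_id, registered):
    # registered maps a desc id -> the original_id of the registered desc,
    # mirroring the _dist_tensors_for_program / _dist_ops_for_program scans.
    for desc_id, original_id in registered.items():
        if node_original_desc_id in (desc_id, original_id):
            return desc_id
    return None

assert match_registered_id(7, {7: 7}) == 7    # node built from the registered desc
assert match_registered_id(7, {12: 7}) == 12  # registered desc is a copy of desc 7
```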
......@@ -22,6 +22,7 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
......@@ -86,6 +87,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
......@@ -170,6 +172,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
for input_name in backward_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in backward_op.desc.output_names():
......
......@@ -25,6 +25,7 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
......@@ -35,9 +36,10 @@ from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
def copy_op_with_new_input_output(block, src_op, **kwargs):
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
dist_op_desc = block.desc.append_op()
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
assert input_name in kwargs
dist_op_desc.set_input(input_name, kwargs[input_name])
......@@ -253,7 +255,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
new_kwargs = copy.deepcopy(kwargs)
new_kwargs['Out@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output(
main_block, backward_op, **new_kwargs)
ctx, main_block, backward_op, **new_kwargs)
else:
# col parallel: matmul + allreduce
assert Y_var_dim_mapping[0] < 0
......@@ -281,7 +283,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
new_kwargs['X@GRAD'] = [intermediate_var_0.name]
matmul_op_desc = copy_op_with_new_input_output(
main_block, backward_op, **new_kwargs)
ctx, main_block, backward_op, **new_kwargs)
# NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
if has_x_grad:
......@@ -304,8 +306,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
X_grad_dist_attr, ctx)
else:
# replicate
matmul_op_desc = copy_op_with_new_input_output(main_block, backward_op,
**kwargs)
matmul_op_desc = copy_op_with_new_input_output(ctx, main_block,
backward_op, **kwargs)
main_block._sync_with_cpp()
......
......@@ -22,6 +22,7 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
......@@ -181,6 +182,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
# create op
new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
new_op_desc.set_input('Shape', Shape_var_list)
new_op_desc.set_input('X', [X_var.name])
......@@ -345,6 +347,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
# create op
new_op_desc = main_block.desc.append_op()
new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
new_op_desc.set_input('Shape', Shape_var_list)
new_op_desc.set_input('X', [X_var.name])
......
......@@ -39,6 +39,7 @@ from .process_group import get_world_process_groups
from .process_group import _g_process_group_map, ProcessGroup
from .utils import make_data_unshard
from .utils import set_grad_var_shape
from .utils import print_program_with_dist_attr
from .utils import SerialProgramInfo
from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
from .cluster import Cluster
......
......@@ -24,6 +24,7 @@ from paddle.distributed.auto_parallel.operators.common import get_distributed_op
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
from .operators.common import BACKWARD_ONLY_DIST_OPS
......@@ -363,8 +364,9 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
def _get_dist_op_backward_implement(backward_op, dist_context,
forward_op_id2forward_op):
dist_op_context = dist_context.dist_op_context
if backward_op.desc.id() in dist_op_context.gradopidx2opidx:
forward_op_id = dist_op_context.gradopidx2opidx[backward_op.desc.id()]
if backward_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
forward_op_id = dist_op_context.grad_op_id_to_op_id[backward_op.desc.id(
)]
forward_op = forward_op_id2forward_op[forward_op_id]
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
......
......@@ -1402,3 +1402,19 @@ def get_standalone_cost_data(distributed_programs):
standalone_cost_data.append(cost_data)
return standalone_cost_data
def set_dist_op_desc_original_id(dist_op_desc, op_desc, dist_context):
op_id = op_desc.id()
op_original_id = op_desc.original_id()
# First, try to set the original id to the id of the op_desc
if op_id in dist_context._dist_ops_for_program:
dist_op_desc.set_original_id(op_id)
return
# Second, try to set the original id to the original_id of the op_desc
elif op_original_id in dist_context._dist_ops_for_program:
dist_op_desc.set_original_id(op_original_id)
return
# Third, raise an error if we cannot find the original id
else:
assert False, "Cannot find the original id in the distributed context"
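A minimal sketch of the decision order implemented above, with hypothetical names, to make the two-level fallback explicit: a first-generation copy is keyed by the source desc's `id()`, while a later-generation copy is keyed by its `original_id()`:

```python
def resolve_root_id(desc_id, original_id, registered_ids):
    # Mirrors set_dist_op_desc_original_id's lookup order (sketch only).
    if desc_id in registered_ids:      # the desc itself was registered
        return desc_id
    if original_id in registered_ids:  # the desc is a copy of a registered one
        return original_id
    raise KeyError("Cannot find the original id in the distributed context")

assert resolve_root_id(101, 101, {101}) == 101  # copy of a registered desc
assert resolve_root_id(205, 101, {101}) == 101  # copy of a copy
```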
......@@ -202,7 +202,7 @@ class ProgramStats(object):
if op.desc.has_attr(op_device_attr_name):
op_device = op.desc.attr(op_device_attr_name)
# Setting the force_cpu of seed to true will make the output of seed in cpu memory,
# reduce the synchronous copy from GPU to CPU in dropout, and reduce the communication hang
added_op = self.block._insert_op(
index=op.idx,
......@@ -957,7 +957,7 @@ def _append_backward_ops_with_checkpoints_(
# added_descs should be in grad_op_descs because it is backward op desc
grad_op_descs.extend(buffer_descs)
# 3.c. add backward ops for all ops in current segment
for op_desc in reversed(added_descs):
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
......@@ -1109,10 +1109,11 @@ def _append_backward_ops_(block,
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
# Build the mapping between the forward op and backward op (Only for auto parallel)
if distop_context is not None:
for op_desc in grad_op_desc:
assert op_desc.id() not in distop_context.gradopidx2opidx
distop_context.gradopidx2opidx[op_desc.id()] = op.desc.id()
assert op_desc.id() not in distop_context.grad_op_id_to_op_id
distop_context.grad_op_id_to_op_id[op_desc.id()] = op.desc.id()
# Set device for grad_op according to forward Op
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
......@@ -1197,6 +1198,12 @@ def _append_backward_ops_(block,
for op_desc in grad_op_descs:
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc)
# Rebuild the mapping because new_op_desc has a different id (Only for auto parallel)
if distop_context is not None:
if op_desc.id() in distop_context.grad_op_id_to_op_id:
distop_context.grad_op_id_to_op_id[new_op_desc.id(
)] = distop_context.grad_op_id_to_op_id[op_desc.id()]
distop_context.grad_op_id_to_op_id.pop(op_desc.id())
new_op_desc._set_attr(op_role_attr_name, backward)
grad_to_var["__current_op_desc__"] = new_op_desc
if callbacks is not None:
......
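Finally, the re-keying at the end of `_append_backward_ops_` is worth spelling out: `append_op` plus `copy_from` produces a desc with a fresh unique id, so the grad-to-forward mapping must migrate from the temporary desc's id to the appended desc's id. A hypothetical illustration:

```python
grad_op_id_to_op_id = {17: 3}  # temporary grad desc id -> forward op id
old_id, new_id = 17, 42        # ids before/after append_op + copy_from
grad_op_id_to_op_id[new_id] = grad_op_id_to_op_id.pop(old_id)
assert grad_op_id_to_op_id == {42: 3}
```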