Commit 6bb4a6fd authored by F fengjiayi

update

Parent b3ea677a
...
@@ -236,15 +236,25 @@ void BindOpDesc(py::module &m) {
       .value("BLOCK", AttrType::BLOCK);
   py::class_<OpDescBind> op_desc(m, "OpDesc", "");
-  op_desc.def("type", &OpDescBind::Type)
+  op_desc
+      .def("__init__",
+           [](OpDescBind &self, const std::string &type,
+              const VariableNameMap &inputs, const VariableNameMap &outputs,
+              const AttributeMap &attrs) {
+             new (&self) OpDescBind(type, inputs, outputs, attrs);
+           })
+      .def("type", &OpDescBind::Type)
       .def("set_type", &OpDescBind::SetType)
       .def("input", &OpDescBind::Input)
       .def("input_names", &OpDescBind::InputNames)
-      .def("set_input", &OpDescBind::SetInput)
       .def("output", &OpDescBind::Output)
       .def("output_names", &OpDescBind::OutputNames)
-      .def("output_arg_names", &OpDescBind::OutputArgumentNames)
+      .def("set_input", &OpDescBind::SetInput)
       .def("set_output", &OpDescBind::SetOutput)
+      .def("input_arg_names", &OpDescBind::InputArgumentNames)
+      .def("output_arg_names", &OpDescBind::OutputArgumentNames)
+      .def("rename_input", &OpDescBind::RenameInput)
+      .def("rename_output", &OpDescBind::RenameOutput)
       .def("has_attr", &OpDescBind::HasAttr)
       .def("attr_type", &OpDescBind::GetAttrType)
       .def("attr_names", &OpDescBind::AttrNames)
...
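Note (added for illustration, not part of the commit): the new `__init__` binding makes `OpDesc` constructible directly from Python, which the `backward.py` changes below rely on. A minimal usage sketch follows; the module import path is an assumption, and arguments are passed positionally because the lambda binding declares no named `py::arg`s.

    # Hypothetical usage sketch of the new bindings (module path assumed).
    import paddle.v2.framework.core as core

    # The new __init__ takes (type, inputs, outputs, attrs).
    op = core.OpDesc("sum", {"X": ["x0", "x1"]}, {"Out": ["out"]}, {})

    # The new renaming hooks rewrite a name wherever it occurs.
    op.rename_input("x0", "x0@RENAME@0")
    op.rename_output("out", "out@RENAME@0")
    print(op.input_arg_names())   # ['x0@RENAME@0', 'x1']
    print(op.output_arg_names())  # ['out@RENAME@0']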
...
@@ -314,6 +314,7 @@ All parameter, weight, gradient are variables in Paddle.
         InferenceOptimize(*(origin.Proto()), &pruned_desc);
         return new ProgramDescBind(pruned_desc);
       });
+  m.def("get_empty_var_name", []() { return framework::kEmptyVarName; });
   m.def_submodule(
       "var_names",
       "The module will return special predefined variable name in Paddle")
...
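Note (added for illustration, not part of the commit): `get_empty_var_name` exposes the C++ `kEmptyVarName` sentinel, the reserved placeholder for unset operator argument slots, so Python code can filter such slots by name, as the variable-creation loop in `backward.py` below does. A trivial sketch under the same assumed module path:

    import paddle.v2.framework.core as core  # module path assumed, as above

    empty = core.get_empty_var_name()  # the reserved placeholder name
    for grad_var_name in ["x@GRAD", empty]:
        if grad_var_name == empty:
            continue  # placeholder slot: nothing to create
        print("would create variable " + grad_var_name)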
...
@@ -5,8 +5,19 @@ import collections
 __all__ = ['append_backward_ops']
 
 
-def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
+def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
+    if begin_idx is None:
+        begin_idx = 0
+    if end_idx is None:
+        end_idx = len(op_desc_list)
+    for i in range(begin_idx, end_idx):
+        op_desc_list[i].rename_input(old_name, new_name)
+        op_desc_list[i].rename_output(old_name, new_name)
+
+
+def backward_impl(block, target_block, no_grad_set, callback=None):
     grad_op_descs = []
+    grad_to_var = {}
     program = block.program
     for each_op in block.ops:
         grad_sub_block_list = []
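Note (added for illustration, not part of the commit): `rename_arg` rewrites one variable name, both where it appears as an input and as an output, across the half-open range `[begin_idx, end_idx)` of an op list, defaulting to the whole list. A self-contained toy with a stub op class standing in for real `OpDesc`s:

    class StubOp(object):
        """Minimal stand-in for core.OpDesc, for demonstration only."""

        def __init__(self, ins, outs):
            self.ins, self.outs = ins, outs

        def rename_input(self, old, new):
            self.ins = [new if n == old else n for n in self.ins]

        def rename_output(self, old, new):
            self.outs = [new if n == old else n for n in self.outs]

    ops = [StubOp(["x"], ["y"]), StubOp(["y"], ["x"]), StubOp(["x"], ["z"])]
    rename_arg(ops, "x", "x@RENAME@0", begin_idx=0, end_idx=2)  # ops[2] untouched
    print([(o.ins, o.outs) for o in ops])
    # [(['x@RENAME@0'], ['y']), (['y'], ['x@RENAME@0']), (['x'], ['z'])]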
...
@@ -14,8 +25,7 @@ def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
             sub_block_idx = each_op.block_attr("sub_block")
             sub_block = program.block(sub_block_idx)
             grad_sub_block = program.create_block(parent_idx=sub_block_idx)
-            backward_impl(sub_block, grad_sub_block, no_grad_set, grad_to_var,
-                          callback)
+            backward_impl(sub_block, grad_sub_block, no_grad_set, callback)
             grad_sub_block_list.append(grad_sub_block)
         grad_op_desc = core.get_grad_op_desc(each_op.desc,
                                              no_grad_set[block.idx],
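Note (added for illustration, not part of the commit): the recursion is depth-first on purpose; an op that owns a sub-block can only get its grad op once the grad sub-block it will point to has been populated. A toy sketch of that ordering, using plain dicts rather than Paddle blocks:

    def gen_grad(block, depth=0):
        # Descend into nested blocks first, then emit grad ops for this block.
        for op in block["ops"]:
            if "sub_block" in op:
                gen_grad(op["sub_block"], depth + 1)
        print("%semit grad ops for %s" % ("  " * depth, block["name"]))

    gen_grad({
        "name": "main",
        "ops": [{"type": "mul"},
                {"type": "while",
                 "sub_block": {"name": "while_body",
                               "ops": [{"type": "add"}]}}],
    })
    # Output:
    #   emit grad ops for while_body
    # emit grad ops for main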
...
@@ -25,16 +35,53 @@ def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
 
     # flatten grad_op_descs
     grad_op_descs = [op for sublist in grad_op_descs for op in sublist] # ?????
-    output_vars = collections.defaultdict(list)
+    pending_sum_ops = []
+    var_rename_count = collections.defaultdict(int)
+    var_inputs = collections.defaultdict(list)
     for pos, op_desc in enumerate(grad_op_descs):
+        for var_name in op_desc.input_arg_names():
+            if len(var_inputs[var_name]) > 1:
+                pending_sum_ops.append((core.OpDesc(
+                    type="sum_op",
+                    inputs=var_inputs[var_name],
+                    outputs=[var_name],
+                    attrs={}), pos))
+                var_inputs[var_name] = [var_name]
         for var_name in op_desc.output_arg_names():
-            output_vars[var_name].append(pos)
-    for var_name, poses in output_vars.iteritems():
-        if len(poses) == 1:
+            if len(var_inputs[var_name]) == 0:
+                # it's the first time we get the variable
+                var_inputs[var_name] = [var_name]
+            else:
+                if len(var_inputs[var_name]) == 1:
+                    new_name = var_name + "@RENAME@" + \
+                        str(var_rename_count[var_name])
+                    var_rename_count[var_name] = var_rename_count[var_name] + 1
+                    # rename original var_name
+                    var_inputs[var_name][0] = new_name
+                    rename_arg(grad_op_descs, var_name, new_name, 0, pos)
+                    rename_arg(pending_sum_ops, var_name, new_name)
+                new_name = var_name + "@RENAME@" + \
+                    str(var_rename_count[var_name])
+                var_rename_count[var_name] = var_rename_count[var_name] + 1
+                op_desc.rename_output(var_name, new_name)
+                var_inputs[var_name].append(new_name)
+    for var_name, inputs in var_inputs.iteritems():
+        if len(inputs) > 1:
+            pending_sum_ops.append((core.OpDesc(
+                type="sum_op", inputs=inputs, outputs=[var_name], attrs={}),
+                len(grad_op_descs)))
+    # Judging from the append order, pending_sum_ops is guaranteed to be
+    # sorted by each sum_op's insertion position.
+    for p in reversed(pending_sum_ops):
+        grad_op_descs.insert(p[1], p[0])
+    # create new gradient variables in the target block
+    for op_desc in grad_op_descs:
+        for grad_var_name in op_desc.output_arg_names():
+            if target_block.has_var(
+                    grad_var_name) or grad_var_name == core.get_empty_var_name(
+                    ):
                 continue
-        renamed_list = []
-        for pos in reversed(sorted(poses)):
-            new_name = var_name + "@RENAMED@" + len(renamed_list)
+            target_block.var(grad_var_name)
 
 
 def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
...
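Note (added for illustration, not part of the commit): the net effect of the bookkeeping above is that when several grad ops write the same gradient variable, every writer is renamed to a unique `@RENAME@<n>` alias and a pending `sum_op` is queued, together with its insertion index, to merge the aliases back under the original name. A self-contained trace of just the renaming part, with plain lists of output names standing in for `OpDesc`s:

    import collections

    # Toy ops: each is just a list of output names; two of them write x@GRAD.
    ops = [["x@GRAD"], ["x@GRAD"], ["y@GRAD"]]

    var_rename_count = collections.defaultdict(int)
    var_inputs = collections.defaultdict(list)
    for pos, outs in enumerate(ops):
        for var_name in list(outs):
            if len(var_inputs[var_name]) == 0:
                var_inputs[var_name] = [var_name]  # first writer keeps its name
            else:
                if len(var_inputs[var_name]) == 1:
                    # a second writer appeared: retroactively rename the first
                    new_name = "%s@RENAME@%d" % (var_name,
                                                 var_rename_count[var_name])
                    var_rename_count[var_name] += 1
                    var_inputs[var_name][0] = new_name
                    for prev in ops[:pos]:  # the rename_arg(..., 0, pos) step
                        prev[:] = [new_name if n == var_name else n
                                   for n in prev]
                new_name = "%s@RENAME@%d" % (var_name,
                                             var_rename_count[var_name])
                var_rename_count[var_name] += 1
                outs[:] = [new_name if n == var_name else n for n in outs]
                var_inputs[var_name].append(new_name)

    print(ops)  # [['x@GRAD@RENAME@0'], ['x@GRAD@RENAME@1'], ['y@GRAD']]
    print({k: v for k, v in var_inputs.items() if len(v) > 1})
    # {'x@GRAD': ['x@GRAD@RENAME@0', 'x@GRAD@RENAME@1']}

The two aliases recorded for `x@GRAD` are exactly what the final `iteritems()` pass in the diff turns into a single trailing `sum_op` that writes the merged result back to `x@GRAD`.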