提交 6bb4a6fd 编写于 作者: F fengjiayi

update

上级 b3ea677a
......@@ -236,15 +236,25 @@ void BindOpDesc(py::module &m) {
.value("BLOCK", AttrType::BLOCK);
py::class_<OpDescBind> op_desc(m, "OpDesc", "");
op_desc.def("type", &OpDescBind::Type)
op_desc
.def("__init__",
[](OpDescBind &self, const std::string &type,
const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs) {
new (&self) OpDescBind(type, inputs, outputs, attrs);
})
.def("type", &OpDescBind::Type)
.def("set_type", &OpDescBind::SetType)
.def("input", &OpDescBind::Input)
.def("input_names", &OpDescBind::InputNames)
.def("set_input", &OpDescBind::SetInput)
.def("output", &OpDescBind::Output)
.def("output_names", &OpDescBind::OutputNames)
.def("output_arg_names", &OpDescBind::OutputArgumentNames)
.def("set_input", &OpDescBind::SetInput)
.def("set_output", &OpDescBind::SetOutput)
.def("input_arg_names", &OpDescBind::InputArgumentNames)
.def("output_arg_names", &OpDescBind::OutputArgumentNames)
.def("rename_input", &OpDescBind::RenameInput)
.def("rename_output", &OpDescBind::RenameOutput)
.def("has_attr", &OpDescBind::HasAttr)
.def("attr_type", &OpDescBind::GetAttrType)
.def("attr_names", &OpDescBind::AttrNames)
......
......@@ -314,6 +314,7 @@ All parameter, weight, gradient are variables in Paddle.
InferenceOptimize(*(origin.Proto()), &pruned_desc);
return new ProgramDescBind(pruned_desc);
});
m.def("get_empty_var_name", []() { return framework::kEmptyVarName; });
m.def_submodule(
"var_names",
"The module will return special predefined variable name in Paddle")
......
......@@ -5,8 +5,19 @@ import collections
__all__ = ['append_backward_ops']
def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
if begin_idx is None:
begin_idx = 0
if end_idx is None:
end_idx = len(op_desc_list)
for i in range(begin_idx, end_idx):
op_desc_list[i].rename_input(old_name, new_name)
op_desc_list[i].rename_output(old_name, new_name)
def backward_impl(block, target_block, no_grad_set, callback=None):
grad_op_descs = []
grad_to_var = {}
program = block.program
for each_op in block.ops:
grad_sub_block_list = []
......@@ -14,8 +25,7 @@ def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
sub_block_idx = each_op.block_attr("sub_block")
sub_block = program.block(sub_block_idx)
grad_sub_block = program.create_block(parent_idx=sub_block_idx)
backward_impl(sub_block, grad_sub_block, no_grad_set, grad_to_var,
callback)
backward_impl(sub_block, grad_sub_block, no_grad_set, callback)
grad_sub_block_list.append(grad_sub_block)
grad_op_desc = core.get_grad_op_desc(each_op.desc,
no_grad_set[block.idx],
......@@ -25,16 +35,53 @@ def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
# flatten grad_op_descs
grad_op_descs = [op for sublist in grad_op_descs for op in sublist] # ?????
output_vars = collections.defaultdict(list)
pending_sum_ops = []
var_rename_count = collections.defaultdict(int)
var_inputs = collections.defaultdict(list)
for pos, op_desc in enumerate(grad_op_descs):
for var_name in op_desc.input_arg_names():
if len(var_inputs[var_name]) > 1:
pending_sum_ops.append((core.OpDesc(
type="sum_op",
inputs=var_inputs[var_name],
output=[var_name],
attrs={}), pos))
var_inputs[var_name] = [var_name]
for var_name in op_desc.output_arg_names():
output_vars[var_name].append(pos)
for var_name, poses in output_vars.iteritems():
if len(poses) == 1:
continue
renamed_list = []
for pos in reversed(sorted(poses)):
new_name = var_name + "@RENAMED@" + len(renamed_list)
if len(var_inputs[var_name]) == 0:
# it's the first time we get the variable
var_inputs[var_name] = var_name
else:
if len(var_inputs[var_name] == 1):
new_name = var_name + "@RENAME@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] = var_rename_count[var_name] + 1
# rename original var_name
var_inputs[var_name][0] = new_name
rename_arg(grad_op_descs, var_name, new_name, 0, pos)
rename_arg(pending_sum_ops, var_name, new_name)
new_name = var_name + "@RENAME@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] = var_rename_count[var_name] + 1
op_desc.rename_output(var_name, new_name)
var_inputs[var_name].append(new_name)
for var_name, inputs in var_inputs.iteritems():
if len(inputs) > 1:
pending_sum_ops.append((core.OpDesc(
type="sum_op", inputs=inputs, outputs=var_name, attrs={}),
len(grad_op_descs)))
# 根据append的顺序可以看出pending_sum_ops一定是根据sum_op的插入位置排序的
for p in reversed(pending_sum_ops):
grad_op_descs.insert(p[1], p[0])
# create new gradient variables in the target block
for op_desc in grad_op_descs:
for grad_var_name in op_desc.output_arg_names():
if target_block.has_var(
grad_var_name) or grad_var_name == core.get_empty_var_name(
):
continue
target_block.var(grad_var_name)
def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册