提交 590e6111 编写于 作者: F fengjiayi

update

上级 6bb4a6fd
......@@ -157,6 +157,7 @@ void BindBlockDesc(py::module &m) {
.def_property_readonly("parent", &BlockDescBind::Parent)
.def("append_op", &BlockDescBind::AppendOp,
py::return_value_policy::reference)
.def("append_allocated_op", &BlockDescBind::AppendAllocatedOp)
.def("prepend_op", &BlockDescBind::PrependOp,
py::return_value_policy::reference)
.def("var",
......
......@@ -15,7 +15,11 @@ def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
op_desc_list[i].rename_output(old_name, new_name)
def backward_impl(block, target_block, no_grad_set, callback=None):
def backward_impl(block,
target_block,
no_grad_set,
grad_info_map,
callback=None):
grad_op_descs = []
grad_to_var = {}
program = block.program
......@@ -25,7 +29,8 @@ def backward_impl(block, target_block, no_grad_set, callback=None):
sub_block_idx = each_op.block_attr("sub_block")
sub_block = program.block(sub_block_idx)
grad_sub_block = program.create_block(parent_idx=sub_block_idx)
backward_impl(sub_block, grad_sub_block, no_grad_set, callback)
backward_impl(sub_block, grad_sub_block, no_grad_set, grad_info_map,
callback)
grad_sub_block_list.append(grad_sub_block)
grad_op_desc = core.get_grad_op_desc(each_op.desc,
no_grad_set[block.idx],
......@@ -71,17 +76,28 @@ def backward_impl(block, target_block, no_grad_set, callback=None):
pending_sum_ops.append((core.OpDesc(
type="sum_op", inputs=inputs, outputs=var_name, attrs={}),
len(grad_op_descs)))
# TODO: remove op in no grad set
# 根据append的顺序可以看出pending_sum_ops一定是根据sum_op的插入位置排序的
for p in reversed(pending_sum_ops):
grad_op_descs.insert(p[1], p[0])
# create new gradient variables in the target block
# create new gradient variables in the target block desc
for op_desc in grad_op_descs:
for grad_var_name in op_desc.output_arg_names():
if target_block.has_var(
if target_block.desc.has_var(
grad_var_name) or grad_var_name == core.get_empty_var_name(
):
continue
target_block.var(grad_var_name)
target_block.desc.var(grad_var_name)
if not grad_to_var.has_key(grad_var_name):
continue
grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
target_block)
# insert backward operators to target_block
for op_desc in grad_op_descs:
target_block.desc.append_allocated_op(op_desc)
target_block.sync_with_cpp()
def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册