提交 590e6111 编写于 作者: F fengjiayi

update

上级 6bb4a6fd
...@@ -157,6 +157,7 @@ void BindBlockDesc(py::module &m) { ...@@ -157,6 +157,7 @@ void BindBlockDesc(py::module &m) {
.def_property_readonly("parent", &BlockDescBind::Parent) .def_property_readonly("parent", &BlockDescBind::Parent)
.def("append_op", &BlockDescBind::AppendOp, .def("append_op", &BlockDescBind::AppendOp,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("append_allocated_op", &BlockDescBind::AppendAllocatedOp)
.def("prepend_op", &BlockDescBind::PrependOp, .def("prepend_op", &BlockDescBind::PrependOp,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("var", .def("var",
......
...@@ -15,7 +15,11 @@ def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None): ...@@ -15,7 +15,11 @@ def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
op_desc_list[i].rename_output(old_name, new_name) op_desc_list[i].rename_output(old_name, new_name)
def backward_impl(block, target_block, no_grad_set, callback=None): def backward_impl(block,
target_block,
no_grad_set,
grad_info_map,
callback=None):
grad_op_descs = [] grad_op_descs = []
grad_to_var = {} grad_to_var = {}
program = block.program program = block.program
...@@ -25,7 +29,8 @@ def backward_impl(block, target_block, no_grad_set, callback=None): ...@@ -25,7 +29,8 @@ def backward_impl(block, target_block, no_grad_set, callback=None):
sub_block_idx = each_op.block_attr("sub_block") sub_block_idx = each_op.block_attr("sub_block")
sub_block = program.block(sub_block_idx) sub_block = program.block(sub_block_idx)
grad_sub_block = program.create_block(parent_idx=sub_block_idx) grad_sub_block = program.create_block(parent_idx=sub_block_idx)
backward_impl(sub_block, grad_sub_block, no_grad_set, callback) backward_impl(sub_block, grad_sub_block, no_grad_set, grad_info_map,
callback)
grad_sub_block_list.append(grad_sub_block) grad_sub_block_list.append(grad_sub_block)
grad_op_desc = core.get_grad_op_desc(each_op.desc, grad_op_desc = core.get_grad_op_desc(each_op.desc,
no_grad_set[block.idx], no_grad_set[block.idx],
...@@ -71,17 +76,28 @@ def backward_impl(block, target_block, no_grad_set, callback=None): ...@@ -71,17 +76,28 @@ def backward_impl(block, target_block, no_grad_set, callback=None):
pending_sum_ops.append((core.OpDesc( pending_sum_ops.append((core.OpDesc(
type="sum_op", inputs=inputs, outputs=var_name, attrs={}), type="sum_op", inputs=inputs, outputs=var_name, attrs={}),
len(grad_op_descs))) len(grad_op_descs)))
# TODO: remove op in no grad set
# 根据append的顺序可以看出pending_sum_ops一定是根据sum_op的插入位置排序的 # 根据append的顺序可以看出pending_sum_ops一定是根据sum_op的插入位置排序的
for p in reversed(pending_sum_ops): for p in reversed(pending_sum_ops):
grad_op_descs.insert(p[1], p[0]) grad_op_descs.insert(p[1], p[0])
# create new gradient variables in the target block # create new gradient variables in the target block desc
for op_desc in grad_op_descs: for op_desc in grad_op_descs:
for grad_var_name in op_desc.output_arg_names(): for grad_var_name in op_desc.output_arg_names():
if target_block.has_var( if target_block.desc.has_var(
grad_var_name) or grad_var_name == core.get_empty_var_name( grad_var_name) or grad_var_name == core.get_empty_var_name(
): ):
continue continue
target_block.var(grad_var_name) target_block.desc.var(grad_var_name)
if not grad_to_var.has_key(grad_var_name):
continue
grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
target_block)
# insert backward operators to target_block
for op_desc in grad_op_descs:
target_block.desc.append_allocated_op(op_desc)
target_block.sync_with_cpp()
def append_backward_ops(loss, parameter_list=None, no_grad_set=None): def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册