From 590e6111f164b559230273496c90ed1879b2dc47 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 19 Dec 2017 17:46:13 +0800 Subject: [PATCH] update --- paddle/pybind/protobuf.cc | 1 + python/paddle/v2/fluid/backward.py | 26 +++++++++++++++++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index bb9872f9f7..d05eb94644 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -157,6 +157,7 @@ void BindBlockDesc(py::module &m) { .def_property_readonly("parent", &BlockDescBind::Parent) .def("append_op", &BlockDescBind::AppendOp, py::return_value_policy::reference) + .def("append_allocated_op", &BlockDescBind::AppendAllocatedOp) .def("prepend_op", &BlockDescBind::PrependOp, py::return_value_policy::reference) .def("var", diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index a399a9712d..5eb7794948 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -15,7 +15,11 @@ def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None): op_desc_list[i].rename_output(old_name, new_name) -def backward_impl(block, target_block, no_grad_set, callback=None): +def backward_impl(block, + target_block, + no_grad_set, + grad_info_map, + callback=None): grad_op_descs = [] grad_to_var = {} program = block.program @@ -25,7 +29,8 @@ def backward_impl(block, target_block, no_grad_set, callback=None): sub_block_idx = each_op.block_attr("sub_block") sub_block = program.block(sub_block_idx) grad_sub_block = program.create_block(parent_idx=sub_block_idx) - backward_impl(sub_block, grad_sub_block, no_grad_set, callback) + backward_impl(sub_block, grad_sub_block, no_grad_set, grad_info_map, + callback) grad_sub_block_list.append(grad_sub_block) grad_op_desc = core.get_grad_op_desc(each_op.desc, no_grad_set[block.idx], @@ -71,17 +76,28 @@ def backward_impl(block, target_block, no_grad_set, callback=None): pending_sum_ops.append((core.OpDesc( type="sum_op", inputs=inputs, outputs=var_name, attrs={}), len(grad_op_descs))) + # TODO: remove op in no grad set + # 根据append的顺序可以看出pending_sum_ops一定是根据sum_op的插入位置排序的 for p in reversed(pending_sum_ops): grad_op_descs.insert(p[1], p[0]) - # create new gradient variables in the target block + # create new gradient variables in the target block desc for op_desc in grad_op_descs: for grad_var_name in op_desc.output_arg_names(): - if target_block.has_var( + if target_block.desc.has_var( grad_var_name) or grad_var_name == core.get_empty_var_name( ): continue - target_block.var(grad_var_name) + target_block.desc.var(grad_var_name) + if not grad_to_var.has_key(grad_var_name): + continue + grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, + target_block) + # insert backward operators to target_block + for op_desc in grad_op_descs: + target_block.desc.append_allocated_op(op_desc) + + target_block.sync_with_cpp() def append_backward_ops(loss, parameter_list=None, no_grad_set=None): -- GitLab