From b3ea677a2b7c052793e242bc8a699cea34257201 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 18 Dec 2017 11:25:18 +0800 Subject: [PATCH] update --- paddle/pybind/protobuf.cc | 1 + paddle/pybind/pybind.cc | 2 +- python/paddle/v2/fluid/backward.py | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 6c8f06cccb..f67aa4a81e 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -243,6 +243,7 @@ void BindOpDesc(py::module &m) { .def("set_input", &OpDescBind::SetInput) .def("output", &OpDescBind::Output) .def("output_names", &OpDescBind::OutputNames) + .def("output_arg_names", &OpDescBind::OutputArgumentNames) .def("set_output", &OpDescBind::SetOutput) .def("has_attr", &OpDescBind::HasAttr) .def("attr_type", &OpDescBind::GetAttrType) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 1faf24bcb8..cd4887d63b 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -282,7 +282,7 @@ All parameter, weight, gradient are variables in Paddle. } return ret_values; }); - m.def("get_grad_op_descs", + m.def("get_grad_op_desc", [](const OpDescBind &op_desc, const std::unordered_set &no_grad_set, std::unordered_map &grad_to_var, diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index 3a128b8e61..1756f1a7af 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -1,5 +1,6 @@ from paddle.v2.fluid import framework as framework from . import core +import collections __all__ = ['append_backward_ops'] @@ -20,6 +21,20 @@ def backward_impl(block, target_block, no_grad_set, grad_to_var, callback): no_grad_set[block.idx], grad_to_var, grad_sub_block_list) grad_op_descs.append(grad_op_desc) + # grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...] + # flatten grad_op_descs + grad_op_descs = [op for sublist in grad_op_descs for op in sublist] # ????? + + output_vars = collections.defaultdict(list) + for pos, op_desc in enumerate(grad_op_descs): + for var_name in op_desc.output_arg_names(): + output_vars[var_name].append(pos) + for var_name, poses in output_vars.iteritems(): + if len(poses) == 1: + continue + renamed_list = [] + for pos in reversed(sorted(poses)): + new_name = var_name + "@RENAMED@" + len(renamed_list) def append_backward_ops(loss, parameter_list=None, no_grad_set=None): -- GitLab