提交 b3ea677a 编写于 作者: F fengjiayi

update

上级 f2e76008
...@@ -243,6 +243,7 @@ void BindOpDesc(py::module &m) { ...@@ -243,6 +243,7 @@ void BindOpDesc(py::module &m) {
.def("set_input", &OpDescBind::SetInput) .def("set_input", &OpDescBind::SetInput)
.def("output", &OpDescBind::Output) .def("output", &OpDescBind::Output)
.def("output_names", &OpDescBind::OutputNames) .def("output_names", &OpDescBind::OutputNames)
.def("output_arg_names", &OpDescBind::OutputArgumentNames)
.def("set_output", &OpDescBind::SetOutput) .def("set_output", &OpDescBind::SetOutput)
.def("has_attr", &OpDescBind::HasAttr) .def("has_attr", &OpDescBind::HasAttr)
.def("attr_type", &OpDescBind::GetAttrType) .def("attr_type", &OpDescBind::GetAttrType)
......
...@@ -282,7 +282,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -282,7 +282,7 @@ All parameter, weight, gradient are variables in Paddle.
} }
return ret_values; return ret_values;
}); });
m.def("get_grad_op_descs", m.def("get_grad_op_desc",
[](const OpDescBind &op_desc, [](const OpDescBind &op_desc,
const std::unordered_set<std::string> &no_grad_set, const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> &grad_to_var, std::unordered_map<std::string, std::string> &grad_to_var,
......
from paddle.v2.fluid import framework as framework from paddle.v2.fluid import framework as framework
from . import core from . import core
import collections
__all__ = ['append_backward_ops'] __all__ = ['append_backward_ops']
...@@ -20,6 +21,20 @@ def backward_impl(block, target_block, no_grad_set, grad_to_var, callback): ...@@ -20,6 +21,20 @@ def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
no_grad_set[block.idx], no_grad_set[block.idx],
grad_to_var, grad_sub_block_list) grad_to_var, grad_sub_block_list)
grad_op_descs.append(grad_op_desc) grad_op_descs.append(grad_op_desc)
# grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
# flatten grad_op_descs
grad_op_descs = [op for sublist in grad_op_descs for op in sublist] # ?????
output_vars = collections.defaultdict(list)
for pos, op_desc in enumerate(grad_op_descs):
for var_name in op_desc.output_arg_names():
output_vars[var_name].append(pos)
for var_name, poses in output_vars.iteritems():
if len(poses) == 1:
continue
renamed_list = []
for pos in reversed(sorted(poses)):
new_name = var_name + "@RENAMED@" + len(renamed_list)
def append_backward_ops(loss, parameter_list=None, no_grad_set=None): def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册