提交 b3ea677a 编写于 作者: F fengjiayi

update

上级 f2e76008
......@@ -243,6 +243,7 @@ void BindOpDesc(py::module &m) {
.def("set_input", &OpDescBind::SetInput)
.def("output", &OpDescBind::Output)
.def("output_names", &OpDescBind::OutputNames)
.def("output_arg_names", &OpDescBind::OutputArgumentNames)
.def("set_output", &OpDescBind::SetOutput)
.def("has_attr", &OpDescBind::HasAttr)
.def("attr_type", &OpDescBind::GetAttrType)
......
......@@ -282,7 +282,7 @@ All parameter, weight, gradient are variables in Paddle.
}
return ret_values;
});
m.def("get_grad_op_descs",
m.def("get_grad_op_desc",
[](const OpDescBind &op_desc,
const std::unordered_set<std::string> &no_grad_set,
std::unordered_map<std::string, std::string> &grad_to_var,
......
from paddle.v2.fluid import framework as framework
from . import core
import collections
__all__ = ['append_backward_ops']
......@@ -20,6 +21,20 @@ def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
no_grad_set[block.idx],
grad_to_var, grad_sub_block_list)
grad_op_descs.append(grad_op_desc)
# grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
# flatten grad_op_descs
grad_op_descs = [op for sublist in grad_op_descs for op in sublist] # ?????
output_vars = collections.defaultdict(list)
for pos, op_desc in enumerate(grad_op_descs):
for var_name in op_desc.output_arg_names():
output_vars[var_name].append(pos)
for var_name, poses in output_vars.iteritems():
if len(poses) == 1:
continue
renamed_list = []
for pos in reversed(sorted(poses)):
new_name = var_name + "@RENAMED@" + len(renamed_list)
def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册