提交 9a0ef7d2 编写于 作者: Q qiaolongfei

append_backward return map to python

上级 ec783d6b
......@@ -412,8 +412,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return backward_descs;
}
std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
ParamGradInfoMap AppendBackward(
ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_var_names;
no_grad_var_names.reserve(no_grad_vars.size() + 1);
......
......@@ -36,10 +36,11 @@ struct GradVarInfo {
int op_idx_;
};
// TODO(jiayi): Add target as parameter and generate backward op
// according to target.
std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
using ParamGradInfoMap = std::unordered_map<std::string /*fwd_var_name*/,
GradVarInfo /*grad_var_info*/>;
ParamGradInfoMap AppendBackward(
ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework
......
......@@ -120,7 +120,19 @@ void BindProgramDesc(py::module &m) {
.def("append_backward",
[](ProgramDescBind &program_desc, const VarDescBind &target,
const std::unordered_set<std::string> &no_grad_vars) {
ParamGradInfoMap param_grad_map =
AppendBackward(program_desc, target, no_grad_vars);
std::unordered_map<
std::string, std::tuple<std::string /* grad_var_name */,
int /* block_idx */, int /* op_idx */>>
retv;
for (auto it = param_grad_map.begin(); it != param_grad_map.end();
++it) {
const auto &grad_info = it->second;
retv[it->first] = std::make_tuple(
grad_info.name_, grad_info.block_idx_, grad_info.op_idx_);
}
return retv;
})
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size)
......
......@@ -57,8 +57,16 @@ class TestProgram(unittest.TestCase):
"mul", "elementwise_add", "fill_constant", "elementwise_add_grad",
"mul_grad"
]
def grad_name(name):
return name + "@GRAD"
actual_ops = []
prog.append_backward(target, set())
param_to_grad = prog.append_backward(target, set())
for var_name in ("x1", "y1", "out1", "b1"):
self.assertEqual(param_to_grad[var_name][0], grad_name(var_name))
self.assertEqual(param_to_grad[var_name][1], 0)
for op in block.all_ops():
actual_ops.append(op.type())
print(actual_ops)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册