提交 9a0ef7d2 编写于 作者: Q qiaolongfei

append_backward return map to python

上级 ec783d6b
...@@ -412,9 +412,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -412,9 +412,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return backward_descs; return backward_descs;
} }
std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/> ParamGradInfoMap AppendBackward(
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars) { const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_var_names; std::unordered_set<std::string> no_grad_var_names;
no_grad_var_names.reserve(no_grad_vars.size() + 1); no_grad_var_names.reserve(no_grad_vars.size() + 1);
no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix); no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
......
...@@ -36,11 +36,12 @@ struct GradVarInfo { ...@@ -36,11 +36,12 @@ struct GradVarInfo {
int op_idx_; int op_idx_;
}; };
// TODO(jiayi): Add target as parameter and generate backward op using ParamGradInfoMap = std::unordered_map<std::string /*fwd_var_name*/,
// according to target. GradVarInfo /*grad_var_info*/>;
std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, ParamGradInfoMap AppendBackward(
const std::unordered_set<std::string>& no_grad_vars); ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -120,7 +120,19 @@ void BindProgramDesc(py::module &m) { ...@@ -120,7 +120,19 @@ void BindProgramDesc(py::module &m) {
.def("append_backward", .def("append_backward",
[](ProgramDescBind &program_desc, const VarDescBind &target, [](ProgramDescBind &program_desc, const VarDescBind &target,
const std::unordered_set<std::string> &no_grad_vars) { const std::unordered_set<std::string> &no_grad_vars) {
AppendBackward(program_desc, target, no_grad_vars); ParamGradInfoMap param_grad_map =
AppendBackward(program_desc, target, no_grad_vars);
std::unordered_map<
std::string, std::tuple<std::string /* grad_var_name */,
int /* block_idx */, int /* op_idx */>>
retv;
for (auto it = param_grad_map.begin(); it != param_grad_map.end();
++it) {
const auto &grad_info = it->second;
retv[it->first] = std::make_tuple(
grad_info.name_, grad_info.block_idx_, grad_info.op_idx_);
}
return retv;
}) })
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size) .def("num_blocks", &ProgramDescBind::Size)
......
...@@ -57,8 +57,16 @@ class TestProgram(unittest.TestCase): ...@@ -57,8 +57,16 @@ class TestProgram(unittest.TestCase):
"mul", "elementwise_add", "fill_constant", "elementwise_add_grad", "mul", "elementwise_add", "fill_constant", "elementwise_add_grad",
"mul_grad" "mul_grad"
] ]
def grad_name(name):
return name + "@GRAD"
actual_ops = [] actual_ops = []
prog.append_backward(target, set()) param_to_grad = prog.append_backward(target, set())
for var_name in ("x1", "y1", "out1", "b1"):
self.assertEqual(param_to_grad[var_name][0], grad_name(var_name))
self.assertEqual(param_to_grad[var_name][1], 0)
for op in block.all_ops(): for op in block.all_ops():
actual_ops.append(op.type()) actual_ops.append(op.type())
print(actual_ops) print(actual_ops)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册