提交 a4b17225 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #4809 from jacquesqiao/backward-return-map

Backward return map
...@@ -274,9 +274,10 @@ static bool AllGradInSet(const std::vector<std::string>& names, ...@@ -274,9 +274,10 @@ static bool AllGradInSet(const std::vector<std::string>& names,
} }
static void CreateGradVarInBlock( static void CreateGradVarInBlock(
std::unordered_map<std::string, GradVarInfo>* grad_var_record, size_t grad_op_start_index,
BlockDescBind* block_desc, size_t grad_op_start_index, const std::unordered_map<std::string, std::string>& param_name_map,
const std::unordered_map<std::string, std::string>& param_name_map) { BlockDescBind* block_desc,
std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
auto ops = block_desc->AllOps(); auto ops = block_desc->AllOps();
for (size_t op_index = grad_op_start_index; op_index < ops.size(); for (size_t op_index = grad_op_start_index; op_index < ops.size();
++op_index) { ++op_index) {
...@@ -422,8 +423,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -422,8 +423,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return backward_descs; return backward_descs;
} }
std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/> ParamGradInfoMap AppendBackward(
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars) { const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_var_names; std::unordered_set<std::string> no_grad_var_names;
no_grad_var_names.reserve(no_grad_vars.size() + 1); no_grad_var_names.reserve(no_grad_vars.size() + 1);
...@@ -461,11 +462,11 @@ AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, ...@@ -461,11 +462,11 @@ AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
root_block->Var(fill_one_op_out); root_block->Var(fill_one_op_out);
// create grad_var for all blocks in this program // create grad_var for all blocks in this program
CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var); CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
for (size_t block_index = forward_block_num; for (size_t block_index = forward_block_num;
block_index < program_desc.Size(); ++block_index) { block_index < program_desc.Size(); ++block_index) {
CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0, CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
grad_to_var); &retv);
} }
return retv; return retv;
} }
......
...@@ -36,10 +36,11 @@ struct GradVarInfo { ...@@ -36,10 +36,11 @@ struct GradVarInfo {
int op_idx_; int op_idx_;
}; };
// TODO(jiayi): Add target as parameter and generate backward op using ParamGradInfoMap = std::unordered_map<std::string /*fwd_var_name*/,
// according to target. GradVarInfo /*grad_var_info*/>;
std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, ParamGradInfoMap AppendBackward(
ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars); const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework } // namespace framework
......
...@@ -120,7 +120,19 @@ void BindProgramDesc(py::module &m) { ...@@ -120,7 +120,19 @@ void BindProgramDesc(py::module &m) {
.def("append_backward", .def("append_backward",
[](ProgramDescBind &program_desc, const VarDescBind &target, [](ProgramDescBind &program_desc, const VarDescBind &target,
const std::unordered_set<std::string> &no_grad_vars) { const std::unordered_set<std::string> &no_grad_vars) {
ParamGradInfoMap param_grad_map =
AppendBackward(program_desc, target, no_grad_vars); AppendBackward(program_desc, target, no_grad_vars);
std::unordered_map<
std::string, std::tuple<std::string /* grad_var_name */,
int /* block_idx */, int /* op_idx */>>
retv;
for (auto it = param_grad_map.begin(); it != param_grad_map.end();
++it) {
const auto &grad_info = it->second;
retv[it->first] = std::make_tuple(
grad_info.name_, grad_info.block_idx_, grad_info.op_idx_);
}
return retv;
}) })
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size) .def("num_blocks", &ProgramDescBind::Size)
......
...@@ -57,11 +57,18 @@ class TestProgram(unittest.TestCase): ...@@ -57,11 +57,18 @@ class TestProgram(unittest.TestCase):
"mul", "elementwise_add", "fill_constant", "elementwise_add_grad", "mul", "elementwise_add", "fill_constant", "elementwise_add_grad",
"mul_grad" "mul_grad"
] ]
def grad_name(name):
return name + "@GRAD"
actual_ops = [] actual_ops = []
prog.append_backward(target, set()) param_to_grad = prog.append_backward(target, set())
for var_name in ("x1", "y1", "out1", "b1"):
self.assertEqual(param_to_grad[var_name][0], grad_name(var_name))
self.assertEqual(param_to_grad[var_name][1], 0)
for op in block.all_ops(): for op in block.all_ops():
actual_ops.append(op.type()) actual_ops.append(op.type())
print(actual_ops)
self.assertEqual(actual_ops, expect_ops) self.assertEqual(actual_ops, expect_ops)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册