Commit a4b17225, authored by Qiao Longfei, committed by GitHub

Merge pull request #4809 from jacquesqiao/backward-return-map

Backward return map
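With this change, AppendBackward returns a map from each forward parameter name to the GradVarInfo of its gradient variable (its name, the index of the block it lives in, and the index of the op that creates it), instead of returning nothing. The patch introduces the ParamGradInfoMap alias for this map, reorders the parameters of CreateGradVarInBlock so that inputs precede the output argument, exposes the map through the append_backward Python binding as a dict of tuples, and extends the unit test to check the returned mapping.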
@@ -274,9 +274,10 @@ static bool AllGradInSet(const std::vector<std::string>& names,
 }
 
 static void CreateGradVarInBlock(
-    std::unordered_map<std::string, GradVarInfo>* grad_var_record,
-    BlockDescBind* block_desc, size_t grad_op_start_index,
-    const std::unordered_map<std::string, std::string>& param_name_map) {
+    size_t grad_op_start_index,
+    const std::unordered_map<std::string, std::string>& param_name_map,
+    BlockDescBind* block_desc,
+    std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
   auto ops = block_desc->AllOps();
   for (size_t op_index = grad_op_start_index; op_index < ops.size();
        ++op_index) {
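The reorder above places the input arguments first and the output argument grad_var_record last, the more conventional C++ ordering; the call sites in AppendBackward below are updated to match.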
@@ -422,9 +423,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
   return backward_descs;
 }
 
-std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
-AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
-               const std::unordered_set<std::string>& no_grad_vars) {
+ParamGradInfoMap AppendBackward(
+    ProgramDescBind& program_desc, const VarDescBind& target,
+    const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_var_names;
   no_grad_var_names.reserve(no_grad_vars.size() + 1);
   no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
@@ -461,11 +462,11 @@ AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
   root_block->Var(fill_one_op_out);
 
   // create grad_var for all blocks in this program
-  CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var);
+  CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
   for (size_t block_index = forward_block_num;
        block_index < program_desc.Size(); ++block_index) {
-    CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0,
-                         grad_to_var);
+    CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
+                         &retv);
   }
   return retv;
 }
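Note that the same retv map is threaded through every CreateGradVarInBlock call: once for the root block starting at forward_op_num, so that only the newly appended backward ops are scanned, and once from op index 0 for each backward-only block appended after forward_block_num. The accumulated map is what AppendBackward now hands back to its caller.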
......
@@ -36,11 +36,12 @@ struct GradVarInfo {
   int op_idx_;
 };
 
-// TODO(jiayi): Add target as parameter and generate backward op
-// according to target.
-std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
-AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
-               const std::unordered_set<std::string>& no_grad_vars);
+using ParamGradInfoMap = std::unordered_map<std::string /*fwd_var_name*/,
+                                            GradVarInfo /*grad_var_info*/>;
+
+ParamGradInfoMap AppendBackward(
+    ProgramDescBind& program_desc, const VarDescBind& target,
+    const std::unordered_set<std::string>& no_grad_vars);
 
 }  // namespace framework
 }  // namespace paddle
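To illustrate the new return value, here is a minimal hedged sketch of a caller iterating the returned ParamGradInfoMap. The header path and the surrounding function are assumptions for illustration; the field names name_, block_idx_ and op_idx_ come from GradVarInfo as used in the Python binding below, and the "@GRAD" suffix from the updated test.

    #include <string>
    #include <unordered_set>

    #include "paddle/framework/backward.h"  // assumed location of AppendBackward

    namespace paddle {
    namespace framework {

    // Hedged sketch, not part of the patch: append the backward pass for
    // `target` and inspect where each parameter's gradient ended up.
    void InspectGradInfo(ProgramDescBind& program_desc, const VarDescBind& target) {
      ParamGradInfoMap param_to_grad =
          AppendBackward(program_desc, target, std::unordered_set<std::string>());
      for (const auto& entry : param_to_grad) {
        const std::string& fwd_name = entry.first;  // forward variable, e.g. "x1"
        const GradVarInfo& info = entry.second;
        // info.name_      -> gradient variable name, e.g. "x1@GRAD"
        // info.block_idx_ -> index of the block holding the gradient variable
        // info.op_idx_    -> index of the op that created it in that block
        (void)fwd_name;
        (void)info;
      }
    }

    }  // namespace framework
    }  // namespace paddle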
@@ -120,7 +120,19 @@ void BindProgramDesc(py::module &m) {
       .def("append_backward",
            [](ProgramDescBind &program_desc, const VarDescBind &target,
               const std::unordered_set<std::string> &no_grad_vars) {
-             AppendBackward(program_desc, target, no_grad_vars);
+             ParamGradInfoMap param_grad_map =
+                 AppendBackward(program_desc, target, no_grad_vars);
+             std::unordered_map<
+                 std::string, std::tuple<std::string /* grad_var_name */,
+                                         int /* block_idx */, int /* op_idx */>>
+                 retv;
+             for (auto it = param_grad_map.begin(); it != param_grad_map.end();
+                  ++it) {
+               const auto &grad_info = it->second;
+               retv[it->first] = std::make_tuple(
+                   grad_info.name_, grad_info.block_idx_, grad_info.op_idx_);
+             }
+             return retv;
            })
       .def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
       .def("num_blocks", &ProgramDescBind::Size)
......
@@ -57,11 +57,18 @@ class TestProgram(unittest.TestCase):
             "mul", "elementwise_add", "fill_constant", "elementwise_add_grad",
             "mul_grad"
         ]
 
+        def grad_name(name):
+            return name + "@GRAD"
+
         actual_ops = []
-        prog.append_backward(target, set())
+        param_to_grad = prog.append_backward(target, set())
+
+        for var_name in ("x1", "y1", "out1", "b1"):
+            self.assertEqual(param_to_grad[var_name][0], grad_name(var_name))
+            self.assertEqual(param_to_grad[var_name][1], 0)
 
         for op in block.all_ops():
             actual_ops.append(op.type())
         print(actual_ops)
         self.assertEqual(actual_ops, expect_ops)
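In the returned dict, index 0 of each tuple is the gradient variable's name and index 1 is its block index, matching the tuple layout built in the binding above; all four parameters here get their gradients in block 0, the root block. The op index at position 2 is not asserted.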
......