提交 2befb9f9 编写于 作者: Q qiaolongfei

optimizer backward CreateGradVarInBlock input output order

上级 9a0ef7d2
......@@ -274,9 +274,10 @@ static bool AllGradInSet(const std::vector<std::string>& names,
}
static void CreateGradVarInBlock(
std::unordered_map<std::string, GradVarInfo>* grad_var_record,
BlockDescBind* block_desc, size_t grad_op_start_index,
const std::unordered_map<std::string, std::string>& param_name_map) {
size_t grad_op_start_index,
const std::unordered_map<std::string, std::string>& param_name_map,
BlockDescBind* block_desc,
std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
auto ops = block_desc->AllOps();
for (size_t op_index = grad_op_start_index; op_index < ops.size();
++op_index) {
......@@ -451,11 +452,11 @@ ParamGradInfoMap AppendBackward(
root_block->NewVar(fill_one_op_out);
// create grad_var for all blocks in this program
CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var);
CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
for (size_t block_index = forward_block_num;
block_index < program_desc.Size(); ++block_index) {
CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0,
grad_to_var);
CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
&retv);
}
return retv;
}
......
......@@ -69,7 +69,6 @@ class TestProgram(unittest.TestCase):
for op in block.all_ops():
actual_ops.append(op.type())
print(actual_ops)
self.assertEqual(actual_ops, expect_ops)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册