提交 2befb9f9 编写于 作者: Q qiaolongfei

optimizer backward CreateGradVarInBlock input output order

上级 9a0ef7d2
...@@ -274,9 +274,10 @@ static bool AllGradInSet(const std::vector<std::string>& names, ...@@ -274,9 +274,10 @@ static bool AllGradInSet(const std::vector<std::string>& names,
} }
static void CreateGradVarInBlock( static void CreateGradVarInBlock(
std::unordered_map<std::string, GradVarInfo>* grad_var_record, size_t grad_op_start_index,
BlockDescBind* block_desc, size_t grad_op_start_index, const std::unordered_map<std::string, std::string>& param_name_map,
const std::unordered_map<std::string, std::string>& param_name_map) { BlockDescBind* block_desc,
std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
auto ops = block_desc->AllOps(); auto ops = block_desc->AllOps();
for (size_t op_index = grad_op_start_index; op_index < ops.size(); for (size_t op_index = grad_op_start_index; op_index < ops.size();
++op_index) { ++op_index) {
...@@ -451,11 +452,11 @@ ParamGradInfoMap AppendBackward( ...@@ -451,11 +452,11 @@ ParamGradInfoMap AppendBackward(
root_block->NewVar(fill_one_op_out); root_block->NewVar(fill_one_op_out);
// create grad_var for all blocks in this program // create grad_var for all blocks in this program
CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var); CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
for (size_t block_index = forward_block_num; for (size_t block_index = forward_block_num;
block_index < program_desc.Size(); ++block_index) { block_index < program_desc.Size(); ++block_index) {
CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0, CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
grad_to_var); &retv);
} }
return retv; return retv;
} }
......
...@@ -69,7 +69,6 @@ class TestProgram(unittest.TestCase): ...@@ -69,7 +69,6 @@ class TestProgram(unittest.TestCase):
for op in block.all_ops(): for op in block.all_ops():
actual_ops.append(op.type()) actual_ops.append(op.type())
print(actual_ops)
self.assertEqual(actual_ops, expect_ops) self.assertEqual(actual_ops, expect_ops)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册