From 2befb9f9722243fe405f8a0c491cdbe6bd029e93 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 14 Oct 2017 20:10:13 -0700 Subject: [PATCH] optimizer backward CreateGradVarInBlock input output order --- paddle/framework/backward.cc | 13 +++++++------ python/paddle/v2/framework/tests/test_program.py | 1 - 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 102fe2e67f..07bc66c51f 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -274,9 +274,10 @@ static bool AllGradInSet(const std::vector& names, } static void CreateGradVarInBlock( - std::unordered_map* grad_var_record, - BlockDescBind* block_desc, size_t grad_op_start_index, - const std::unordered_map& param_name_map) { + size_t grad_op_start_index, + const std::unordered_map& param_name_map, + BlockDescBind* block_desc, + std::unordered_map* grad_var_record) { auto ops = block_desc->AllOps(); for (size_t op_index = grad_op_start_index; op_index < ops.size(); ++op_index) { @@ -451,11 +452,11 @@ ParamGradInfoMap AppendBackward( root_block->NewVar(fill_one_op_out); // create grad_var for all blocks in this program - CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var); + CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv); for (size_t block_index = forward_block_num; block_index < program_desc.Size(); ++block_index) { - CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0, - grad_to_var); + CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index), + &retv); } return retv; } diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py index cd209b0585..6ef806bee1 100644 --- a/python/paddle/v2/framework/tests/test_program.py +++ b/python/paddle/v2/framework/tests/test_program.py @@ -69,7 +69,6 @@ class TestProgram(unittest.TestCase): for op in block.all_ops(): actual_ops.append(op.type()) - print(actual_ops) self.assertEqual(actual_ops, expect_ops) -- GitLab