From ac5893e8ccbccb37d9868db57155ecbb032d3734 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Thu, 10 Aug 2017 19:01:00 +0800 Subject: [PATCH] Fix grad_op_builder --- paddle/framework/grad_op_builder.cc | 5 +---- paddle/framework/grad_op_builder_test.cc | 5 +---- paddle/framework/op_registry.h | 1 - 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index c51a563a612..35db0cf716d 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -21,8 +21,6 @@ namespace framework { class OpRegistry; -using VarIndexMap = std::unordered_map; - enum class OpArgType { IN, OUT }; static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, @@ -36,10 +34,9 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, const OpProto& proto = OpProtos().at(src_op->type_); const auto& src_arg_list = src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); - for (const auto& arg : src_arg_list) { if (arg.no_gradient() && !is_grad) continue; - std::string src_name = arg.name(); + const std::string src_name = arg.name(); std::string dst_name = is_grad ? GradVarName(src_name) : src_name; dst_inout[dst_name].reserve(src_inout.at(src_name).size()); for (auto& var_name : src_inout.at(src_name)) { diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 19da90967f0..85e745322b9 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -110,15 +110,12 @@ TEST(GradOpBuilder, IOIgnoredInGradient) { f::OpRegistry::CreateGradOp(*test_op); // 'In2' and 'Out2' are ignored in gradient calculating - ASSERT_EQ(grad_test_op->inputs_.size(), 3UL + 2UL + 2UL); + ASSERT_EQ(grad_test_op->inputs_.size(), 2UL + 1UL + 2UL); EXPECT_EQ(grad_test_op->Input("In1"), "in1"); - EXPECT_EQ(grad_test_op->Inputs("In2_mult"), - std::vector({f::kEmptyVarName, f::kEmptyVarName})); EXPECT_EQ(grad_test_op->Inputs("In3_mult"), std::vector({"in3_1", "in3_2"})); EXPECT_EQ(grad_test_op->Inputs("Out1_mult"), std::vector({"out1_1", "out1_2"})); - EXPECT_EQ(grad_test_op->Input("Out2"), f::kEmptyVarName); EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")), std::vector( {f::GradVarName("out1_1"), f::GradVarName("out1_2")})); diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 03b14ea0216..bb23b6bf65d 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -120,7 +120,6 @@ class OpProtoAndCheckerMaker { class OpRegistry { using OpCreator = std::function; - using VarIndexMap = std::unordered_map; using VarNameMap = std::unordered_map>; public: -- GitLab