diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 5f84eb8c15df3045a79544539e1e53d81b3145af..35db0cf716d04bdc2fd6a9789fcf11f56efdc7f6 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -21,8 +21,6 @@ namespace framework { class OpRegistry; -using VarIndexMap = std::unordered_map; - enum class OpArgType { IN, OUT }; static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, @@ -30,19 +28,19 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, bool is_grad) { const auto& src_inout = src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; - auto& dst_inout = dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_; + const OpProto& proto = OpProtos().at(src_op->type_); const auto& src_arg_list = src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); - for (const auto& arg : src_arg_list) { - const std::string& src_name = arg.name(); + if (arg.no_gradient() && !is_grad) continue; + const std::string src_name = arg.name(); std::string dst_name = is_grad ? GradVarName(src_name) : src_name; + dst_inout[dst_name].reserve(src_inout.at(src_name).size()); for (auto& var_name : src_inout.at(src_name)) { - std::string s = is_grad ? GradVarName(var_name) - : (arg.no_gradient() ? kEmptyVarName : var_name); + std::string s = is_grad ? GradVarName(var_name) : var_name; dst_inout[dst_name].emplace_back(s); } } diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 19da90967f05bf2727be699ccd26314b14220230..85e745322b9b8ad9d73791387644915100e3926b 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -110,15 +110,12 @@ TEST(GradOpBuilder, IOIgnoredInGradient) { f::OpRegistry::CreateGradOp(*test_op); // 'In2' and 'Out2' are ignored in gradient calculating - ASSERT_EQ(grad_test_op->inputs_.size(), 3UL + 2UL + 2UL); + ASSERT_EQ(grad_test_op->inputs_.size(), 2UL + 1UL + 2UL); EXPECT_EQ(grad_test_op->Input("In1"), "in1"); - EXPECT_EQ(grad_test_op->Inputs("In2_mult"), - std::vector({f::kEmptyVarName, f::kEmptyVarName})); EXPECT_EQ(grad_test_op->Inputs("In3_mult"), std::vector({"in3_1", "in3_2"})); EXPECT_EQ(grad_test_op->Inputs("Out1_mult"), std::vector({"out1_1", "out1_2"})); - EXPECT_EQ(grad_test_op->Input("Out2"), f::kEmptyVarName); EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")), std::vector( {f::GradVarName("out1_1"), f::GradVarName("out1_2")})); diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 03b14ea021605147873e641561a11dd9618fd81d..bb23b6bf65d3fb182f706cc00647b87796995c0f 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -120,7 +120,6 @@ class OpProtoAndCheckerMaker { class OpRegistry { using OpCreator = std::function; - using VarIndexMap = std::unordered_map; using VarNameMap = std::unordered_map>; public: diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 173a701fa6f6b0d8e34b38ac369844d065f78319..c7ee76500312aa1c0d3f219f000250127e815a96 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -59,8 +59,8 @@ void ExposeOperator(ClassType &m) { .def("outputs", [](const typename ClassType::type &op) -> std::unordered_map> { - return op.outputs_; - }) + return op.outputs_; + }) .def("__str__", &ClassType::type::DebugString); }