diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 4aaa43d79612111856dd4dfc954ca2bfd8f4fa63..8a5d8532bb32db917b893f7f59039e08d85c8c34 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -26,7 +26,7 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
 
-cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator)
+cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
 cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
 cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc
index b02a599a800668b22e7fe39a10fa6dc132e305bd..3661ce41beba1328d1b1cdd9f0f913e693af9cff 100644
--- a/paddle/framework/grad_op_builder.cc
+++ b/paddle/framework/grad_op_builder.cc
@@ -54,5 +54,67 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
   return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs());
 }
 
+// Copies the variables bound to every `src_type` argument (input or output)
+// of `src_op` onto `dst_op` as `dst_type` arguments. The argument list is
+// taken from the proto registered for the type of `src_op`. With `is_grad`
+// set, argument and variable names are mapped through GradVarName();
+// otherwise arguments marked not_in_gradient are skipped.
+static void TransOpDescArg(const OpDescBind* src_op, const OpArgType& src_type,
+                           bool is_grad, OpDescBind* dst_op,
+                           const OpArgType& dst_type) {
+  PADDLE_ENFORCE(src_op != nullptr,
+                 "Protobuf desc of forward op must not be null.");
+  PADDLE_ENFORCE(dst_op != nullptr,
+                 "Protobuf desc of gradient op must be initialized first.");
+  const auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
+  const auto& src_arg_list =
+      src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
+  for (const auto& arg : src_arg_list) {
+    if (arg.not_in_gradient() && !is_grad) continue;
+    const std::string& src_name = arg.name();
+    std::vector<std::string> vars = src_type == OpArgType::IN
+                                        ? src_op->Input(src_name)
+                                        : src_op->Output(src_name);
+    if (is_grad) {
+      for (std::string& var : vars) {
+        var = GradVarName(var);
+      }
+    }
+    const std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
+    if (dst_type == OpArgType::IN) {
+      dst_op->SetInput(dst_name, vars);
+    } else {
+      dst_op->SetOutput(dst_name, vars);
+    }
+  }
+}
+
+// Fills `grad_op` with the description of the gradient operator of `forw_op`:
+//   - type:    the gradient op type registered for the forward op;
+//   - inputs:  the forward inputs and outputs (except arguments marked
+//              not_in_gradient) plus the gradients of the forward outputs;
+//   - outputs: the gradients of the forward inputs;
+//   - attrs:   a copy of the forward op's attributes.
+void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op) {
+  // Check both descs up front: they are dereferenced right below, before
+  // TransOpDescArg() gets a chance to validate them.
+  PADDLE_ENFORCE(forw_op != nullptr,
+                 "Protobuf desc of forward op must not be null.");
+  PADDLE_ENFORCE(grad_op != nullptr,
+                 "Protobuf desc of gradient op must be initialized first.");
+  auto& info = OpInfoMap::Instance().Get(forw_op->Type());
+  PADDLE_ENFORCE(info.HasGradientOp(), "Operator %s has no gradient operator.",
+                 forw_op->Type());
+
+  grad_op->SetType(info.grad_op_type_);
+
+  TransOpDescArg(forw_op, OpArgType::IN, false, grad_op, OpArgType::IN);
+  TransOpDescArg(forw_op, OpArgType::OUT, false, grad_op, OpArgType::IN);
+  TransOpDescArg(forw_op, OpArgType::OUT, true, grad_op, OpArgType::IN);
+  TransOpDescArg(forw_op, OpArgType::IN, true, grad_op, OpArgType::OUT);
+
+  grad_op->SetAttrMap(forw_op->GetAttrMap());
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/grad_op_builder.h b/paddle/framework/grad_op_builder.h
index 998f8ebbb5f2f4fb8b7e938b5916afd0f8a7930d..b601406061f9f8f24302251c2144b07b6e65717f 100644
--- a/paddle/framework/grad_op_builder.h
+++ b/paddle/framework/grad_op_builder.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/framework/op_desc.h"
 #include "paddle/framework/operator.h"
 
 namespace paddle {
@@ -21,5 +22,9 @@ namespace framework {
 
 OperatorBase* BuildGradOp(const OperatorBase* op);
 
+// Completes grad_op as the description of the gradient operator of forw_op.
+// Both pointers must be non-null.
+void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc
index 9e3ca563c6765637f8471d142d32cec447f0b977..85184e02b6e1a2700a337fce690370d1c5a0346f 100644
--- a/paddle/framework/grad_op_builder_test.cc
+++ b/paddle/framework/grad_op_builder_test.cc
@@ -120,3 +120,42 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
             std::vector<std::string>(
                 {f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
 }
+
+// Checks the gradient op desc that CompleteGradOpDesc() builds for the
+// multi-input/multi-output op "mult_io".
+TEST(GradOpDescBuilder, MutiInOut) {
+  f::OpDescBind forw_op;
+  forw_op.SetType("mult_io");
+  forw_op.SetInput("In1", {"in1"});
+  forw_op.SetInput("In2_mult", {"in2_1", "in2_2", "in2_3"});
+  forw_op.SetInput("In3", {"in3"});
+  forw_op.SetOutput("Out1", {"out1"});
+  forw_op.SetOutput("Out2_mult", {"out2_1", "out2_2"});
+
+  f::OpDescBind grad_op;
+  f::CompleteGradOpDesc(&forw_op, &grad_op);
+
+  ASSERT_EQ(grad_op.InputNames().size(), 3UL + 2UL + 2UL);
+  EXPECT_EQ(grad_op.Input("In1"), std::vector<std::string>({"in1"}));
+  EXPECT_EQ(grad_op.Input("In2_mult"),
+            std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
+  EXPECT_EQ(grad_op.Input("In3"), std::vector<std::string>({"in3"}));
+  EXPECT_EQ(grad_op.Input("Out1"), std::vector<std::string>({"out1"}));
+  EXPECT_EQ(grad_op.Input("Out2_mult"),
+            std::vector<std::string>({"out2_1", "out2_2"}));
+  EXPECT_EQ(grad_op.Input(f::GradVarName("Out1")),
+            std::vector<std::string>({f::GradVarName("out1")}));
+  EXPECT_EQ(grad_op.Input(f::GradVarName("Out2_mult")),
+            std::vector<std::string>(
+                {f::GradVarName("out2_1"), f::GradVarName("out2_2")}));
+
+  ASSERT_EQ(grad_op.OutputNames().size(), 3UL);
+  EXPECT_EQ(grad_op.Output(f::GradVarName("In1")),
+            std::vector<std::string>({f::GradVarName("in1")}));
+  EXPECT_EQ(grad_op.Output(f::GradVarName("In2_mult")),
+            std::vector<std::string>({f::GradVarName("in2_1"),
+                                      f::GradVarName("in2_2"),
+                                      f::GradVarName("in2_3")}));
+  EXPECT_EQ(grad_op.Output(f::GradVarName("In3")),
+            std::vector<std::string>({f::GradVarName("in3")}));
+}
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index 99b5a9c37700adce56f9a83af3792ef113a873ff..0c12c55dc09f6aa064066b5c73bc5e985a57343f 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -89,6 +89,14 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
   need_update_ = true;
 }
 
+// Replaces every attribute of this op with a copy of attr_map and flags the
+// desc as needing an update.
+void OpDescBind::SetAttrMap(
+    const std::unordered_map<std::string, Attribute> &attr_map) {
+  attrs_ = attr_map;
+  need_update_ = true;
+}
+
 Attribute OpDescBind::GetAttr(const std::string &name) const {
   auto it = attrs_.find(name);
   PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
@@ -101,6 +109,12 @@ int OpDescBind::GetBlockAttr(const std::string &name) const {
   return boost::get(it->second)->idx();
 }
 
+// Returns a read-only reference to the attribute map of this op.
+const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
+    const {
+  return attrs_;
+}
+
 void OpDescBind::Sync() {
   if (need_update_) {
     this->op_desc_.mutable_inputs()->Clear();
diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h
index ffc8ac61abfb74e4716f10c457d0fbc18b2e2ab8..0cf7d13971675eb825bcd0c7636896f0862d6ebb 100644
--- a/paddle/framework/op_desc.h
+++ b/paddle/framework/op_desc.h
@@ -60,10 +60,18 @@ class OpDescBind {
 
   void SetBlockAttr(const std::string &name, BlockDescBind &block);
 
+  // Replaces all attributes of this op with a copy of attr_map.
+  // Only used in C++.
+  void SetAttrMap(const std::unordered_map<std::string, Attribute> &attr_map);
+
   Attribute GetAttr(const std::string &name) const;
 
   int GetBlockAttr(const std::string &name) const;
 
+  // Returns a read-only reference to all attributes of this op.
+  // Only used in C++.
+  const std::unordered_map<std::string, Attribute> &GetAttrMap() const;
+
  private:
   struct SetAttrDescVisitor : public boost::static_visitor {
     explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}