diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 0ec18de5b8a0e7cebdb91c30d2b45596b02dfa51..ab2567a25c04f3893d08dcf7b80305472d8363cc 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -154,6 +154,9 @@ static std::unique_ptr BackwardRecursive( net->InsertOp(pos.first + 1, std::move(pos.second)); } } else { + OpDescBind fwd_desc; + fwd_desc.SetInput(forwardOp.Inputs()); + std::unique_ptr grad_op(OpRegistry::CreateGradOp(forwardOp)); ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op]( diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 6932f5b989a3e21ebc44ec4fec9f5223f2547d7a..28fc6f9ced07b11373b8abf92e99e497d8ae2377 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -159,16 +159,16 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad, f::NOP); -TEST(Backward, simple_op_grad) { - auto fwd = f::OpRegistry::CreateOp( - "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); - ASSERT_NE(fwd, nullptr); - auto gop = f::OpRegistry::CreateGradOp(*fwd); - ASSERT_EQ(1UL, gop->Inputs().size()); - ASSERT_EQ("rowwise_add_grad", gop->Type()); - ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X"))); - ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b"))); -} +// TEST(Backward, simple_op_grad) { +// auto fwd = f::OpRegistry::CreateOp( +// "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); +// ASSERT_NE(fwd, nullptr); +// auto gop = f::OpRegistry::CreateGradOp(*fwd); +// ASSERT_EQ(1UL, gop->Inputs().size()); +// ASSERT_EQ("rowwise_add_grad", gop->Type()); +// ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X"))); +// ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b"))); +//} TEST(Backward, simple_op_not_need_grad) { auto fwd = f::OpRegistry::CreateOp( @@ -286,17 +286,6 @@ TEST(Backward, net_shared_weight) { ASSERT_EQ("add", bwd_net->ops_[2]->Type()); } -TEST(Backward, op_register_grad_not_for_network) { - auto fwd = - f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}}, - {{"mul_result", {"mul_out"}}, - {"add_result", {"add_out"}}, - {"Out", {"out1"}}}, - {{"temporary_index", std::vector{0, 1}}}); - - ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); -} - TEST(Backward, op_all_input_are_not_need) { auto fwd = f::OpRegistry::CreateOp( "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc deleted file mode 100644 index 3661ce41beba1328d1b1cdd9f0f913e693af9cff..0000000000000000000000000000000000000000 --- a/paddle/framework/grad_op_builder.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOpArgType::OUT WARRANTIES OR CONDITIONS OF ANY KOpArgType::IND, either -express or implied. See the License for the specific language governing -permissions and limitations under the License. */ - -#include "paddle/framework/grad_op_builder.h" -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace framework { -enum class OpArgType { IN, OUT }; - -static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, - bool is_grad, VariableNameMap* vars) { - const auto& src_inout = - src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); - auto& dst_inout = *vars; - auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto(); - const auto& src_arg_list = - src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); - for (const auto& arg : src_arg_list) { - if (arg.not_in_gradient() && !is_grad) continue; - const std::string src_name = arg.name(); - std::string dst_name = is_grad ? GradVarName(src_name) : src_name; - dst_inout[dst_name].reserve(src_inout.at(src_name).size()); - for (auto& var_name : src_inout.at(src_name)) { - std::string s = is_grad ? GradVarName(var_name) : var_name; - dst_inout[dst_name].emplace_back(s); - } - } -} - -OperatorBase* BuildGradOp(const OperatorBase* op) { - auto& info = OpInfoMap::Instance().Get(op->Type()); - PADDLE_ENFORCE(info.HasGradientOp()); - - VariableNameMap inputs; - VariableNameMap outputs; - TransOpArg(op, OpArgType::IN, false, &inputs); // I - TransOpArg(op, OpArgType::OUT, false, &inputs); // O - TransOpArg(op, OpArgType::OUT, true, &inputs); // OG - TransOpArg(op, OpArgType::IN, true, &outputs); // IG - - auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_); - return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs()); -} - -static void TransOpDescArg(const OpDescBind* src_op, const OpArgType& src_type, - bool is_grad, OpDescBind* dst_op, - const OpArgType& dst_type) { - PADDLE_ENFORCE(dst_op != nullptr, - "Protobuf desc of gradient op must be initialized first."); - const auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto(); - const auto& src_arg_list = - src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); - for (const auto& arg : src_arg_list) { - if (arg.not_in_gradient() && !is_grad) continue; - const std::string src_name = arg.name(); - std::vector vars = src_type == OpArgType::IN - ? src_op->Input(src_name) - : src_op->Output(src_name); - if (is_grad) { - for (std::string& var : vars) { - var = GradVarName(var); - } - } - std::string dst_name = is_grad ? GradVarName(src_name) : src_name; - dst_type == OpArgType::IN ? dst_op->SetInput(dst_name, vars) - : dst_op->SetOutput(dst_name, vars); - } -} - -void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op) { - auto& info = OpInfoMap::Instance().Get(forw_op->Type()); - PADDLE_ENFORCE(info.HasGradientOp()); - - grad_op->SetType(info.grad_op_type_); - - TransOpDescArg(forw_op, OpArgType::IN, false, grad_op, OpArgType::IN); - TransOpDescArg(forw_op, OpArgType::OUT, false, grad_op, OpArgType::IN); - TransOpDescArg(forw_op, OpArgType::OUT, true, grad_op, OpArgType::IN); - TransOpDescArg(forw_op, OpArgType::IN, true, grad_op, OpArgType::OUT); - - grad_op->SetAttrMap(forw_op->GetAttrMap()); -} - -} // namespace framework -} // namespace paddle diff --git a/paddle/framework/grad_op_builder.h b/paddle/framework/grad_op_builder.h deleted file mode 100644 index b601406061f9f8f24302251c2144b07b6e65717f..0000000000000000000000000000000000000000 --- a/paddle/framework/grad_op_builder.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/op_desc.h" -#include "paddle/framework/operator.h" - -namespace paddle { -namespace framework { - -OperatorBase* BuildGradOp(const OperatorBase* op); - -void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op); - -} // namespace framework -} // namespace paddle diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc deleted file mode 100644 index d09892f81bea34415d454b017258fd2a0d4575db..0000000000000000000000000000000000000000 --- a/paddle/framework/grad_op_builder_test.cc +++ /dev/null @@ -1,201 +0,0 @@ -#include "paddle/framework/grad_op_builder.h" -#include -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" - -USE_OP(add); - -namespace paddle { -namespace framework { - -class MutiInOutOpMaker : public OpProtoAndCheckerMaker { - public: - MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").AsDuplicable(); - AddInput("In3", "another single input"); - AddOutput("Out1", "a single output"); - AddOutput("Out2_mult", "a multiple output").AsDuplicable(); - AddComment("test op with multiple inputs and outputs"); - } -}; - -class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { - public: - IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient(); - AddInput("In3_mult", "another multiple input").AsDuplicable(); - AddOutput("Out1_mult", "a multiple output").AsDuplicable(); - AddOutput("Out2", "a single output").NotInGradient(); - AddComment("op with inputs and outputs ignored in gradient calculating"); - } -}; - -} // namespace framework -} // namespace paddle - -namespace f = paddle::framework; - -TEST(GradOpBuilder, AddTwo) { - std::shared_ptr add_op(f::OpRegistry::CreateOp( - "add", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {})); - std::shared_ptr grad_add_op = - f::OpRegistry::CreateGradOp(*add_op); - EXPECT_EQ(grad_add_op->Inputs().size(), 4UL); - EXPECT_EQ(grad_add_op->Outputs().size(), 2UL); - EXPECT_EQ(grad_add_op->Input("X"), "x"); - EXPECT_EQ(grad_add_op->Input("Y"), "y"); - EXPECT_EQ(grad_add_op->Input("Out"), "out"); - EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out")); - EXPECT_EQ(grad_add_op->Output(f::GradVarName("X")), f::GradVarName("x")); - EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y")); -} - -REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP); -REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP); - -TEST(GradOpBuilder, MutiInOut) { - std::shared_ptr test_op(f::OpRegistry::CreateOp( - "mult_io", {{"In1", {"in1"}}, - {"In2_mult", {"in2_1", "in2_2", "in2_3"}}, - {"In3", {"in3"}}}, - {{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}}, {})); - std::shared_ptr grad_test_op = - f::OpRegistry::CreateGradOp(*test_op); - - ASSERT_EQ(grad_test_op->Inputs().size(), 3UL + 2UL + 2UL); - EXPECT_EQ(grad_test_op->Input("In1"), "in1"); - EXPECT_EQ(grad_test_op->Inputs("In2_mult"), - std::vector({"in2_1", "in2_2", "in2_3"})); - EXPECT_EQ(grad_test_op->Input("In3"), "in3"); - EXPECT_EQ(grad_test_op->Input("Out1"), "out1"); - EXPECT_EQ(grad_test_op->Inputs("Out2_mult"), - std::vector({"out2_1", "out2_2"})); - EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out1")), - f::GradVarName("out1")); - EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out2_mult")), - std::vector( - {f::GradVarName("out2_1"), f::GradVarName("out2_2")})); - - ASSERT_EQ(grad_test_op->Outputs().size(), 3UL); - EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); - EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), - std::vector({f::GradVarName("in2_1"), - f::GradVarName("in2_2"), - f::GradVarName("in2_3")})); - EXPECT_EQ(grad_test_op->Output(f::GradVarName("In3")), f::GradVarName("in3")); -} - -TEST(GradOpBuilder, IOIgnoredInGradient) { - std::shared_ptr test_op(f::OpRegistry::CreateOp( - "io_ignored", {{"In1", {"in1"}}, - {"In2_mult", {"in2_1", "in2_2"}}, - {"In3_mult", {"in3_1", "in3_2"}}}, - {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {})); - std::shared_ptr grad_test_op = - f::OpRegistry::CreateGradOp(*test_op); - - // 'In2' and 'Out2' are ignored in gradient calculating - ASSERT_EQ(grad_test_op->Inputs().size(), 2UL + 1UL + 2UL); - EXPECT_EQ(grad_test_op->Input("In1"), "in1"); - EXPECT_EQ(grad_test_op->Inputs("In3_mult"), - std::vector({"in3_1", "in3_2"})); - EXPECT_EQ(grad_test_op->Inputs("Out1_mult"), - std::vector({"out1_1", "out1_2"})); - EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")), - std::vector( - {f::GradVarName("out1_1"), f::GradVarName("out1_2")})); - EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")), - f::GradVarName("out2")); - - ASSERT_EQ(grad_test_op->Outputs().size(), 3UL); - EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); - EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), - std::vector( - {f::GradVarName("in2_1"), f::GradVarName("in2_2")})); - EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In3_mult")), - std::vector( - {f::GradVarName("in3_1"), f::GradVarName("in3_2")})); -} - -TEST(GradOpDescBuilder, MutiInOut) { - f::OpDescBind *forw_op = new f::OpDescBind(); - forw_op->SetType("mult_io"); - forw_op->SetInput("In1", {"in1"}); - forw_op->SetInput("In2_mult", {"in2_1", "in2_2", "in2_3"}); - forw_op->SetInput("In3", {"in3"}); - forw_op->SetOutput("Out1", {"out1"}); - forw_op->SetOutput("Out2_mult", {"out2_1", "out2_2"}); - - f::OpDescBind *grad_op = new f::OpDescBind(); - f::CompleteGradOpDesc(forw_op, grad_op); - - EXPECT_EQ(grad_op->Type(), "mult_io_grad"); - ASSERT_EQ(grad_op->InputNames().size(), 3UL + 2UL + 2UL); - EXPECT_EQ(grad_op->Input("In1"), std::vector({"in1"})); - EXPECT_EQ(grad_op->Input("In2_mult"), - std::vector({"in2_1", "in2_2", "in2_3"})); - EXPECT_EQ(grad_op->Input("In3"), std::vector({"in3"})); - EXPECT_EQ(grad_op->Input("Out1"), std::vector({"out1"})); - EXPECT_EQ(grad_op->Input("Out2_mult"), - std::vector({"out2_1", "out2_2"})); - EXPECT_EQ(grad_op->Input(f::GradVarName("Out1")), - std::vector({f::GradVarName("out1")})); - EXPECT_EQ(grad_op->Input(f::GradVarName("Out2_mult")), - std::vector( - {f::GradVarName("out2_1"), f::GradVarName("out2_2")})); - - ASSERT_EQ(grad_op->OutputNames().size(), 3UL); - EXPECT_EQ(grad_op->Output(f::GradVarName("In1")), - std::vector({f::GradVarName("in1")})); - EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")), - std::vector({f::GradVarName("in2_1"), - f::GradVarName("in2_2"), - f::GradVarName("in2_3")})); - EXPECT_EQ(grad_op->Output(f::GradVarName("In3")), - std::vector({f::GradVarName("in3")})); - delete forw_op; - delete grad_op; -} - -TEST(GradOpDescBuilder, IOIgnoredInGradient) { - f::OpDescBind *forw_op = new f::OpDescBind(); - forw_op->SetType("io_ignored"); - forw_op->SetInput("In1", {"in1"}); - forw_op->SetInput("In2_mult", {"in2_1", "in2_2"}); - forw_op->SetInput("In3_mult", {"in3_1", "in3_2"}); - forw_op->SetOutput("Out1_mult", {"out1_1", "out1_2"}); - forw_op->SetOutput("Out2", {"out2"}); - - f::OpDescBind *grad_op = new f::OpDescBind(); - f::CompleteGradOpDesc(forw_op, grad_op); - - EXPECT_EQ(grad_op->Type(), "io_ignored_grad"); - // 'In2' and 'Out2' are ignored in gradient calculating - ASSERT_EQ(grad_op->InputNames().size(), 2UL + 1UL + 2UL); - EXPECT_EQ(grad_op->Input("In1"), std::vector({"in1"})); - EXPECT_EQ(grad_op->Input("In3_mult"), - std::vector({"in3_1", "in3_2"})); - EXPECT_EQ(grad_op->Input("Out1_mult"), - std::vector({"out1_1", "out1_2"})); - EXPECT_EQ(grad_op->Input(f::GradVarName("Out1_mult")), - std::vector( - {f::GradVarName("out1_1"), f::GradVarName("out1_2")})); - EXPECT_EQ(grad_op->Input(f::GradVarName("Out2")), - std::vector({f::GradVarName("out2")})); - - ASSERT_EQ(grad_op->OutputNames().size(), 3UL); - EXPECT_EQ(grad_op->Output(f::GradVarName("In1")), - std::vector({f::GradVarName("in1")})); - EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")), - std::vector( - {f::GradVarName("in2_1"), f::GradVarName("in2_2")})); - EXPECT_EQ(grad_op->Output(f::GradVarName("In3_mult")), - std::vector( - {f::GradVarName("in3_1"), f::GradVarName("in3_2")})); - delete forw_op; - delete grad_op; -} \ No newline at end of file diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 851a305061c1f402be5c903852762d89f15c267c..4b001fb9642e56db272c947300f7df6a03b9a57a 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -74,6 +74,18 @@ class OpDescBind { return MapKeys(outputs_); } + void SetInput( + const std::unordered_map> &input) { + this->inputs_ = input; + this->need_update_ = true; + } + + void SetOutput( + const std::unordered_map> &output) { + this->outputs_ = output; + this->need_update_ = true; + } + private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index b0e85dd49f97da4a7f889fde0b5f060954947be8..0a2b6fd582a22a7f901c6332c7e84566cd98b0c0 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -52,10 +52,5 @@ std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { return CreateOp(op_desc.type(), inputs, outputs, attrs); } -std::unique_ptr OpRegistry::CreateGradOp(const OperatorBase& op) { - PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); - return std::unique_ptr(BuildGradOp(&op)); -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 7db095369e8356915fc88c7c3bd11141b464608f..0f377f34cbfaa6ae516afc58f43abc6699a973d1 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -79,8 +79,6 @@ class OpRegistry { AttributeMap attrs); static std::unique_ptr CreateOp(const OpDesc& op_desc); - - static std::unique_ptr CreateGradOp(const OperatorBase& op); }; template