diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 9730fdd18bcf2f5011657876811a98cc4cbca859..c034e265fe4837ca22ab969b0e6952677904e05c 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -42,9 +42,9 @@ static std::shared_ptr<OperatorBase> NOP() {
 //
 // no_grad_names the gradient variable names without gradient calculating.
 //
-// uniq_id is a unique index used inside recursively calling BackwardRecursive.
-// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through
-// recursive calling.
+// uniq_id is a unique index used inside recursively calling
+// BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and
+// pass `uniq_id` through recursive calling.
 //
 // returns The backward operator. For simple situation, it is a simple
 // operator. For complex situation, it is a NetOp.
@@ -64,8 +64,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     return NOP();
   }
 
-  // All output gradients of forwarding operator do not need to calculate. Then
-  // all input gradients cannot be computed at all, and we put them into
+  // All output gradients of forwarding operator do not need to calculate.
+  // Then all input gradients cannot be computed at all, and we put them into
   // `no_grad_names` set. Return an NOP.
   if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
@@ -83,8 +83,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     // Because forwardOp is a net op, it can static_cast.
     auto& forwardNet = static_cast<const NetOp&>(forwardOp);
 
-    // Map from output gradient variable name to operator's indices in backward
-    // net. That operator generates that variable.
+    // Map from output gradient variable name to operator's indices in
+    // backward net. That operator generates that variable.
     std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
 
     size_t local_op_id = 0;
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 8adf7e4365d6d044e551c9e66101c7ae023e7cf8..8f437e68041188831a17217099e0b0c96432cda4 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -162,8 +162,8 @@ TEST(Backward, simple_op_grad) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   ASSERT_NE(fwd, nullptr);
   auto gop = f::OpRegistry::CreateGradOp(*fwd);
-  ASSERT_EQ(1UL, gop->inputs_.size());
-  ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
+  ASSERT_EQ(4UL, gop->inputs_.size());
+  ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), gop->inputs_[0]);
   ASSERT_EQ("rowwise_add_grad", gop->type_);
   ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
   ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]);
@@ -360,7 +360,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
             3UL /* external input number */
                 + 1UL /* external output number*/
                 + 1UL /* number of gradient of external output*/
-                - 1UL /*ignoreGradient varable number*/
                 + 2U /* internal variable number*/);
   EXPECT_EQ(grad_fc.outputs_.size(),
             2UL /* input number of mul*/
                 + 2UL /* input number of rowwise_add */
diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc
index dd686cc78246f06cdc3ec7d013086863d7e8fac0..ea5e939c6e26514c2f3c515da5581b29103f75b6 100644
--- a/paddle/framework/grad_op_builder.cc
+++ b/paddle/framework/grad_op_builder.cc
@@ -8,107 +8,97 @@ You may obtain a copy of the License at
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+express or implied. See the License for the specific language governing
+permissions and limitations under the License. */
 
 #include "paddle/framework/grad_op_builder.h"
+#include "paddle/framework/op_proto.pb.h"
 #include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace framework {
 
-OperatorBase* GradOpBuilder::Build() {
-  BuildOpInOutArgList();
-  std::string grad_op_type = OpRegistry::grad_ops().at(op_.type_);
-  OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
-  grad_op->type_ = grad_op_type;
-  CompleteGradOp(grad_op);
-  return grad_op;
-}
+class OpRegistry;
+
+using VarIndexMap = std::unordered_map<std::string, int>;
 
-OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
-                                    const VarIndexMap& var_map,
-                                    const std::vector<int>& format,
-                                    InOutType type) {
-  int idx = var_map.at(var.name());
-  int begin_idx = format.empty() ? idx : format.at(idx);
-  int end_idx = format.empty() ? idx + 1 : format.at(idx + 1);
-  return new OpInOutArg(var.name(), type, !var.ignore_gradient(), begin_idx,
-                        end_idx);
+enum class OpArgType { IN, OUT };
+
+static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
+  std::string key = type == OpArgType::IN ? "input_format" : "output_format";
+  return op->attrs_.count(key)
+             ? &boost::get<std::vector<int>>(op->attrs_.at(key))
+             : nullptr;
 }
 
-void GradOpBuilder::BuildOpInOutArgList() {
-  const OpProto& op_proto = OpRegistry::protos().at(op_.type_);
-  const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_.type_));
-  const std::vector<int>& in_format =
-      op_.attrs_.count("input_format")
-          ? op_.GetAttr<std::vector<int>>("input_format")
-          : std::vector<int>();
-  const std::vector<int>& out_format =
-      op_.attrs_.count("output_format")
-          ? op_.GetAttr<std::vector<int>>("output_format")
-          : std::vector<int>();
-  for (const auto& var : op_proto.inputs()) {
-    arg_list_.emplace_back(
-        std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, in_format, IN)));
-  }
-  for (const auto& var : op_proto.outputs()) {
-    arg_list_.emplace_back(
-        std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, out_format, OUT)));
-  }
+static const std::vector<int>* GetOpFormat(const OperatorBase* op,
+                                           const OpArgType& type) {
+  std::string key = type == OpArgType::IN ? "input_format" : "output_format";
+  return op->attrs_.count(key)
+             ? &boost::get<std::vector<int>>(op->attrs_.at(key))
+             : nullptr;
 }
 
-void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
-                                     std::vector<std::string>& in_out,
-                                     std::vector<int>& format,
-                                     VarIndexMap* varmap, int& idx,
-                                     bool is_grad) const {
-  std::string var_name = arg->proto_name_;
-  if (is_grad) {
-    var_name += OperatorBase::GRAD_VAR_SUFFIX();
-  }
-  (*varmap)[var_name] = idx++;
-  size_t pre_sz = in_out.size();
-  auto base_it = arg->type_ == IN ? op_.inputs_.begin() : op_.outputs_.begin();
-  std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
-            std::back_inserter(in_out));
-  if (is_grad) {
-    for (size_t i = pre_sz; i < in_out.size(); ++i) {
-      in_out[i] += OperatorBase::GRAD_VAR_SUFFIX();
+static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
+                       const OpArgType& src_type, const OpArgType& dst_type,
+                       int& idx, bool is_grad) {
+  const std::vector<std::string>& src_inout =
+      src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
+  const std::vector<int>* src_format = GetOpFormat(src_op, src_type);
+
+  std::vector<std::string>& dst_inout =
+      dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
+  std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
+  const OpProto& proto = OpRegistry::protos().at(src_op->type_);
+  const auto& src_arg_list =
+      src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
+
+  for (const auto& arg : src_arg_list) {
+    std::string src_name = arg.name();
+    std::string dst_name =
+        is_grad ? src_name + OperatorBase::GRAD_VAR_SUFFIX() : src_name;
+    (*dst_op->in_out_idxs_)[dst_name] = idx++;
+    int src_arg_idx = src_op->in_out_idxs_->at(src_name);
+    int src_begin =
+        src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx);
+    int src_end = src_format == nullptr ? src_arg_idx + 1
+                                        : src_format->at(src_arg_idx + 1);
+    for (int i = src_begin; i < src_end; ++i) {
+      std::string s = is_grad ? src_inout[i] + OperatorBase::GRAD_VAR_SUFFIX()
+                              : arg.ignore_gradient()
+                                    ? OperatorBase::EMPTY_VAR_NAME()
+                                    : src_inout[i];
+      dst_inout.emplace_back(s);
+    }
+    if (dst_format != nullptr) {
+      dst_format->push_back(dst_inout.size());
     }
   }
-  format.push_back(in_out.size());
 }
 
-void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
-  grad_op->attrs_ = op_.attrs_;
+OperatorBase* BuildGradOp(const OperatorBase* op) {
+  std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
+  OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
+  grad_op->type_ = grad_op_type;
+  grad_op->attrs_ = op->attrs_;
   grad_op->attrs_.erase("input_format");
   grad_op->attrs_.erase("output_format");
-  VarIndexMap* grad_varmap = new VarIndexMap();
+  if (GetOpFormat(op, OpArgType::IN) != nullptr) {
+    grad_op->attrs_["output_format"] = std::vector<int>({0});
+  }
+  if (GetOpFormat(op, OpArgType::IN) != nullptr ||
+      GetOpFormat(op, OpArgType::OUT) != nullptr) {
+    grad_op->attrs_["input_format"] = std::vector<int>({0});
+  }
+  grad_op->in_out_idxs_.reset(new VarIndexMap());
   int in_idx = 0;
   int out_idx = 0;
-  std::vector<int> in_format({0});
-  std::vector<int> out_format({0});
-  for (const auto& arg : arg_list_) {
-    // op_'s inputs_ and outputs_
-    if (arg->needed_in_grad_) {
-      AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
-                       in_idx, false);
-    }
-    if (arg->type_ == IN) {
-      // gradients of op_'s inputs_
-      AddArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
-                       out_idx, true);
-    } else {
-      // gradients of op_'s outputs_
-      AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
-                       in_idx, true);
-    }
-  }
-  grad_op->attrs_["input_format"] = in_format;
-  grad_op->attrs_["output_format"] = out_format;
-  grad_op->in_out_idxs_.reset(grad_varmap);
+  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, in_idx, false);   // I
+  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, false);  // G
+  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, true);   // OG
+  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true);  // IG
+  return grad_op;
 }
 
 }  // namespace framework
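
Note on the `input_format` / `output_format` convention that `TransOpArg` depends on: the attribute stores cumulative slot offsets, so argument `k` of an op owns the name range `[format[k], format[k+1])` of the flat `inputs_`/`outputs_` list, and an absent format means one slot per argument. Below is a minimal stand-alone illustration of that slicing (plain C++, not Paddle API; `ArgSlice` is a hypothetical helper mirroring how `src_begin`/`src_end` are derived above):

    #include <cassert>
    #include <string>
    #include <vector>

    // Slice argument k out of a flat name list using a cumulative format
    // vector. An empty format vector means each argument owns exactly one slot.
    std::vector<std::string> ArgSlice(const std::vector<std::string>& names,
                                      const std::vector<int>& format, int k) {
      int begin = format.empty() ? k : format.at(k);
      int end = format.empty() ? k + 1 : format.at(k + 1);
      return {names.begin() + begin, names.begin() + end};
    }

    int main() {
      // input_format {0, 1, 4, 5}: arg 0 -> [0,1), arg 1 -> [1,4), arg 2 -> [4,5)
      std::vector<std::string> names = {"in1", "in2_1", "in2_2", "in2_3", "in3"};
      std::vector<int> format = {0, 1, 4, 5};
      assert(ArgSlice(names, format, 1) ==
             (std::vector<std::string>{"in2_1", "in2_2", "in2_3"}));
      return 0;
    }

This is the same layout the `mult_io` test below exercises through its `input_format` attribute.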
diff --git a/paddle/framework/grad_op_builder.h b/paddle/framework/grad_op_builder.h
index cc7a76f3726e00a08fbe06bca4c9b9f5bad466b4..cf235de6c267a4a1feb7afd3e4dbe7a6a668ee5e 100644
--- a/paddle/framework/grad_op_builder.h
+++ b/paddle/framework/grad_op_builder.h
@@ -1,48 +1,11 @@
 #pragma once
 
-#include "paddle/framework/op_proto.pb.h"
 #include "paddle/framework/operator.h"
 
 namespace paddle {
 namespace framework {
-class OpRegistry;
 
-enum InOutType { IN, OUT };
-
-struct OpInOutArg {
-  OpInOutArg(const std::string& proto_name, const InOutType& type,
-             bool needed_in_grad, size_t begin_idx, size_t end_idx)
-      : proto_name_(proto_name),
-        type_(type),
-        needed_in_grad_(needed_in_grad),
-        begin_idx_(begin_idx),
-        end_idx_(end_idx) {}
-
-  std::string proto_name_;
-  InOutType type_;
-  bool needed_in_grad_;
-  size_t begin_idx_;
-  size_t end_idx_;
-};
-
-class GradOpBuilder {
-  using VarIndexMap = std::unordered_map<std::string, int>;
-
- public:
-  GradOpBuilder(const OperatorBase& op) : op_(op) {}
-  OperatorBase* Build();
-
- private:
-  OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
-                       const std::vector<int>& format, InOutType type);
-  void BuildOpInOutArgList();
-  void AddArgIntoGradOp(const OpInOutArg* arg, std::vector<std::string>& in_out,
-                        std::vector<int>& format, VarIndexMap* varmap, int& idx,
-                        bool is_grad) const;
-  void CompleteGradOp(OperatorBase* grad_op) const;
-  const OperatorBase& op_;
-  std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
-};
+OperatorBase* BuildGradOp(const OperatorBase* op);
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc
index e9cf3b9798db2cbfb8d26259ae9a6741fbae8278..96d7f309d67b15c000ab8ce3769931322fbca880 100644
--- a/paddle/framework/grad_op_builder_test.cc
+++ b/paddle/framework/grad_op_builder_test.cc
@@ -8,10 +8,49 @@ USE_OP(add_two);
 namespace paddle {
 namespace framework {
 
+class NOP : public OperatorBase {
+ public:
+  void InferShape(const Scope &scope) const override {}
+  void Run(const Scope &scope,
+           const platform::DeviceContext &dev_ctx) const override {}
+};
+
+class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("In1", "a single input");
+    AddInput("In2_mult", "a multiple input").SetMultiple();
+    AddInput("In3", "another single input");
+    AddOutput("Out1", "a single output");
+    AddOutput("Out2_mult", "a multiple output").SetMultiple();
+    AddComment("test op with multiple inputs and outputs");
+  }
+};
+
+class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("In1", "a single input");
+    AddInput("In2_mult", "a multiple input").SetMultiple().IgnoreGradient();
+    AddInput("In3_mult", "another multiple input").SetMultiple();
+    AddOutput("Out1_mult", "a multiple output").SetMultiple();
+    AddOutput("Out2", "a single output").IgnoreGradient();
+    AddComment("op with inputs and outputs ignored in gradient calculating");
+  }
+};
+
+}  // namespace framework
+}  // namespace paddle
+
+namespace f = paddle::framework;
+
 TEST(GradOpBuilder, AddTwo) {
-  std::shared_ptr<OperatorBase> add_op(
-      OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
-  std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(*add_op);
+  std::shared_ptr<f::OperatorBase> add_op(
+      f::OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
+  std::shared_ptr<f::OperatorBase> grad_add_op =
+      f::OpRegistry::CreateGradOp(*add_op);
   EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
   EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
   EXPECT_EQ(grad_add_op->Input("X"), "x");
@@ -22,5 +61,85 @@ TEST(GradOpBuilder, AddTwo) {
   EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
 }
 
-}  // namespace framework
-}  // namespace paddle
\ No newline at end of file
+REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker);
+REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP);
+REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker);
+REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP);
+
+TEST(GradOpBuilder, MutiInOut) {
+  f::AttributeMap attrs{{"input_format", std::vector<int>{0, 1, 4, 5}},
+                        {"output_format", std::vector<int>{0, 1, 3}}};
+  std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
+      "mult_io", {"in1", "in2_1", "in2_2", "in2_3", "in3"},
+      {"out1", "out2_1", "out2_2"}, attrs));
+  std::shared_ptr<f::OperatorBase> grad_test_op =
+      f::OpRegistry::CreateGradOp(*test_op);
+
+  ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL);
+  EXPECT_EQ(grad_test_op->Input("In1"), "in1");
+  EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
+            std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
+  EXPECT_EQ(grad_test_op->Input("In3"), "in3");
+  EXPECT_EQ(grad_test_op->Input("Out1"), "out1");
+  EXPECT_EQ(grad_test_op->Inputs("Out2_mult"),
+            std::vector<std::string>({"out2_1", "out2_2"}));
+  EXPECT_EQ(grad_test_op->Input("Out1" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "out1" + f::OperatorBase::GRAD_VAR_SUFFIX());
+  EXPECT_EQ(
+      grad_test_op->Inputs("Out2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+      std::vector<std::string>(
+          {"out2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
+           "out2_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
+
+  ASSERT_EQ(grad_test_op->outputs_.size(), 5UL);
+  EXPECT_EQ(grad_test_op->Output("In1" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "in1" + f::OperatorBase::GRAD_VAR_SUFFIX());
+  EXPECT_EQ(
+      grad_test_op->Outputs("In2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+      std::vector<std::string>({"in2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
+                                "in2_2" + f::OperatorBase::GRAD_VAR_SUFFIX(),
+                                "in2_3" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
+  EXPECT_EQ(grad_test_op->Output("In3" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "in3" + f::OperatorBase::GRAD_VAR_SUFFIX());
+}
+
+TEST(GradOpBuilder, IOIgnoredInGradient) {
+  f::AttributeMap attrs{{"input_format", std::vector<int>{0, 1, 3, 5}},
+                        {"output_format", std::vector<int>{0, 2, 3}}};
+  std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
+      "io_ignored", {"in1", "in2_1", "in2_2", "in3_1", "in3_2"},
+      {"out1_1", "out1_2", "out2"}, attrs));
+  std::shared_ptr<f::OperatorBase> grad_test_op =
+      f::OpRegistry::CreateGradOp(*test_op);
+
+  // 'In2' and 'Out2' are ignored in gradient calculating
+  ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL);
+  EXPECT_EQ(grad_test_op->Input("In1"), "in1");
+  EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
+            std::vector<std::string>({f::OperatorBase::EMPTY_VAR_NAME(),
+                                      f::OperatorBase::EMPTY_VAR_NAME()}));
+  EXPECT_EQ(grad_test_op->Inputs("In3_mult"),
+            std::vector<std::string>({"in3_1", "in3_2"}));
+  EXPECT_EQ(grad_test_op->Inputs("Out1_mult"),
+            std::vector<std::string>({"out1_1", "out1_2"}));
+  EXPECT_EQ(grad_test_op->Input("Out2"), f::OperatorBase::EMPTY_VAR_NAME());
+  EXPECT_EQ(
+      grad_test_op->Inputs("Out1_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+      std::vector<std::string>(
+          {"out1_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
+           "out1_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
+  EXPECT_EQ(grad_test_op->Input("Out2" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "out2" + f::OperatorBase::GRAD_VAR_SUFFIX());
+
+  ASSERT_EQ(grad_test_op->outputs_.size(), 5UL);
+  EXPECT_EQ(grad_test_op->Output("In1" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "in1" + f::OperatorBase::GRAD_VAR_SUFFIX());
+  EXPECT_EQ(
+      grad_test_op->Outputs("In2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+      std::vector<std::string>({"in2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
+                                "in2_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
+  EXPECT_EQ(
+      grad_test_op->Outputs("In3_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+      std::vector<std::string>({"in3_1" + f::OperatorBase::GRAD_VAR_SUFFIX(),
+                                "in3_2" + f::OperatorBase::GRAD_VAR_SUFFIX()}));
+}
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 3e72e391266066de9e4114e68b43b066c15254db..9a975185f04da8df5ba22e457936218756e7c4bc 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -306,8 +306,7 @@ class OpRegistry {
   static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
     PADDLE_ENFORCE(!op.IsNetOp(),
                    "Use framework::Backward to get backward ops");
-    GradOpBuilder builder(op);
-    std::shared_ptr<OperatorBase> grad_op(builder.Build());
+    std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
     grad_op->Init();
     return grad_op;
   }
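
The net effect of the four `TransOpArg` calls is a fixed argument layout for every generated grad op: its inputs are I (forward inputs), then G (forward outputs), then OG (gradients of forward outputs); its outputs are IG (gradients of forward inputs); and any non-gradient slot whose proto entry sets `ignore_gradient()` is stored as `EMPTY_VAR_NAME()`. A self-contained sketch of that layout follows (plain C++, not Paddle code; `kGradSuffix` is a hypothetical stand-in for `GRAD_VAR_SUFFIX()`), which also accounts for the `4UL` input count now asserted in `backward_test.cc`:

    #include <iostream>
    #include <string>
    #include <vector>

    static const std::string kGradSuffix = "@GRAD";  // stand-in for GRAD_VAR_SUFFIX()

    int main() {
      // Forward op: Out = rowwise_add(X, b).
      std::vector<std::string> fwd_inputs = {"X", "b"};
      std::vector<std::string> fwd_outputs = {"Out"};

      // Grad op inputs: I, then G, then OG, mirroring the first three
      // TransOpArg calls in BuildGradOp.
      std::vector<std::string> grad_inputs;
      for (const auto& v : fwd_inputs) grad_inputs.push_back(v);                 // I
      for (const auto& v : fwd_outputs) grad_inputs.push_back(v);                // G
      for (const auto& v : fwd_outputs) grad_inputs.push_back(v + kGradSuffix);  // OG

      // Grad op outputs: IG, the fourth TransOpArg call.
      std::vector<std::string> grad_outputs;
      for (const auto& v : fwd_inputs) grad_outputs.push_back(v + kGradSuffix);  // IG

      // Prints: 4 inputs (X, b, Out, Out@GRAD), 2 outputs (X@GRAD, b@GRAD).
      std::cout << grad_inputs.size() << " inputs, "
                << grad_outputs.size() << " outputs\n";
      return 0;
    }

In `simple_op_grad`, the first of those four input slots compares equal to `EMPTY_VAR_NAME()`: per the `TransOpArg` logic above, a non-gradient slot is stored as the empty placeholder whenever its proto entry sets `ignore_gradient()`.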