diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 063b108500d95c94d5859cf6e1a5a88dcdb2ed31..c966f97c2d5b553f6ab67bb2f7aac27108b80409 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -28,14 +28,15 @@ namespace paddle { namespace framework { static inline std::unique_ptr CreateGradOp( - const OperatorBase& op) { + const OperatorBase& op, + const std::unordered_set& no_grad_set) { OpDescBind op_desc; op_desc.SetInputMap(op.Inputs()); op_desc.SetOutputMap(op.Outputs()); op_desc.SetType(op.Type()); op_desc.SetAttrMap(op.Attrs()); auto& info = OpInfoMap::Instance().Get(op.Type()); - auto grad_descs = info.GradOpMaker()(op_desc); + auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set); std::vector> grad_ops; grad_ops.reserve(grad_descs.size()); std::transform(grad_descs.begin(), grad_descs.end(), @@ -187,7 +188,8 @@ static std::unique_ptr BackwardRecursive( net->InsertOp(pos.first + 1, std::move(pos.second)); } } else { - std::unique_ptr grad_op(CreateGradOp(forwardOp)); + std::unique_ptr grad_op( + CreateGradOp(forwardOp, no_grad_names)); ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op]( const std::string& grad_input) { @@ -272,7 +274,7 @@ std::vector> MakeOpGrad( const std::unique_ptr& op_desc, std::unordered_set& no_grad_vars) { std::vector> grad_op_descs; - // All input gradients of forwarding operator do not need to calculat. + // All input gradients of forwarding operator do not need to calculate. const std::vector& inputs = op_desc->InputArgumentNames(); if (AllGradInSet(inputs, no_grad_vars)) { return grad_op_descs; // empty vector @@ -286,7 +288,9 @@ std::vector> MakeOpGrad( return grad_op_descs; // empty vector } - grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get()); + grad_op_descs = OpInfoMap::Instance() + .Get(op_desc->Type()) + .GradOpMaker()(*op_desc, no_grad_vars); std::list> pending_fill_zeros_ops; for (auto& desc : grad_op_descs) { @@ -301,11 +305,6 @@ std::vector> MakeOpGrad( pending_fill_zeros_ops.push_back(std::move(fill_zeros_op)); } } - for (const std::string& out_name : desc->OutputArgumentNames()) { - if (no_grad_vars.count(out_name)) { - desc->Rename(out_name, kEmptyVarName); - } - } } for (auto& p : pending_fill_zeros_ops) { diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 8fd8c826187e3e3f830fdb318c4edacbad8b7333..9b15331dff380a2c437e1cc42da92d7fd239d31c 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -169,6 +169,45 @@ class MultInOutOpMaker : public OpProtoAndCheckerMaker { } }; +class MinusGradOpDescMaker : public GradOpDescMakerBase { + public: + using GradOpDescMakerBase::GradOpDescMakerBase; + + std::vector> operator()() const override { + std::vector> retv; + auto x_g = InputGrad("X"); + if (!x_g.empty()) { + auto *op_desc = new OpDescBind(); + op_desc->SetType("scale"); + op_desc->SetInput("X", OutputGrad("Out")); + op_desc->SetOutput("Out", x_g); + op_desc->SetAttr("scale", 1.0f); + retv.emplace_back(op_desc); + } + + auto y_g = InputGrad("Y"); + if (!y_g.empty()) { + auto *op_desc = new OpDescBind(); + op_desc->SetType("scale"); + op_desc->SetInput("X", OutputGrad("Out")); + op_desc->SetOutput("Out", y_g); + op_desc->SetAttr("scale", -1.0f); + retv.emplace_back(op_desc); + } + return retv; + } +}; + +class MinusOpMaker : public OpProtoAndCheckerMaker { + public: + MinusOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", ""); + AddInput("Y", ""); + AddOutput("Out", ""); + AddComment("minus for unittest"); + } +}; } // namespace framework } // namespace paddle @@ -187,6 +226,7 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad, f::NOP); REGISTER_OP(mult_in_out, f::NOP, f::MultInOutOpMaker, mult_in_out_grad, f::NOP); +REGISTER_OPERATOR(minus, f::NOP, f::MinusOpMaker, f::MinusGradOpDescMaker); TEST(Backward, simple_op_not_need_grad) { auto fwd = f::OpRegistry::CreateOp( @@ -395,12 +435,13 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { 2UL /* external input number */ + 1UL /* external output number*/ + 1UL /* number of gradient of external output*/ - + 2U /* internal variable number*/); + + 2UL /* internal variable number*/ + ); EXPECT_EQ(grad_fc.Outputs(all).size(), 2UL /* input number of mul*/ - + 2UL /* input number of rowwise_add - */ - + 1UL /* input number of sigmod */); + + 2UL /* input number of rowwise_add*/ + + 1UL /* input number of sigmod */ + - 1UL /* out2 is not needed*/); EXPECT_EQ(bwd_net->ops_[1]->Inputs(all).size(), 0UL); EXPECT_EQ(bwd_net->ops_[1]->Outputs(all).size(), 0UL); EXPECT_EQ(bwd_net->ops_[2]->Inputs(all).size(), 0UL); @@ -580,8 +621,7 @@ TEST(Backward, intermedia_var_no_grad) { std::vector({f::GradVarName("out4")})); EXPECT_EQ(grad_op4->Output(f::GradVarName("X")), std::vector({f::GradVarName("out1")})); - EXPECT_EQ(grad_op4->Output(f::GradVarName("Y")), - std::vector({f::kEmptyVarName})); + EXPECT_EQ(grad_op4->Output(f::GradVarName("Y")), std::vector()); } TEST(Backward, var_no_grad) { @@ -619,8 +659,7 @@ TEST(Backward, var_no_grad) { std::vector({f::GradVarName("z2")})); EXPECT_EQ(grad_op2->Output(f::GradVarName("X")), std::vector({f::GradVarName("y1")})); - EXPECT_EQ(grad_op2->Output(f::GradVarName("H")), - std::vector({f::kEmptyVarName})); + EXPECT_EQ(grad_op2->Output(f::GradVarName("H")), std::vector()); f::OpDescBind *fill_zero_op = block->AllOps()[3]; ASSERT_EQ(fill_zero_op->Type(), "fill_zeros_like"); @@ -718,4 +757,19 @@ TEST(Backward, shared_var) { std::vector({f::GradVarName("x1")})); EXPECT_EQ(grad_op1->Output(f::GradVarName("b")), std::vector({f::GradVarName("b1")})); +} + +TEST(Backward, half_backward) { + f::ProgramDesc *program_desc = GetNewProgramDesc(); + f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::BlockDescBind *block = program.Block(0); + auto *op1 = block->AppendOp(); + op1->SetType("minus"); + op1->SetInput("X", {"a"}); + op1->SetInput("Y", {"b"}); + op1->SetOutput("Out", {"out"}); + + AppendBackward(program, {"b"}); + auto ops = block->AllOps(); + ASSERT_EQ(2UL, ops.size()); } \ No newline at end of file diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h index daa474e8c5a223589018720da29a5c3363b5934d..ca8584b78ab081138e0d73b8a71ae4cc111a1b4c 100644 --- a/paddle/framework/details/op_registry.h +++ b/paddle/framework/details/op_registry.h @@ -97,8 +97,10 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { - info->grad_op_maker_ = [](const OpDescBind& fwd_op) { - T maker(fwd_op); + info->grad_op_maker_ = []( + const OpDescBind& fwd_op, + const std::unordered_set& no_grad_set) { + T maker(fwd_op, no_grad_set); return maker(); }; } diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h index e9ae6e22060850fe229998d3b651d08a5ca2033a..d7366b11ec94403e0d8d5d8a3485896f0dc691c0 100644 --- a/paddle/framework/grad_op_desc_maker.h +++ b/paddle/framework/grad_op_desc_maker.h @@ -13,6 +13,8 @@ limitations under the License. */ #pragma once +#include +#include #include "paddle/framework/op_desc.h" #include "paddle/framework/operator.h" @@ -21,27 +23,44 @@ namespace framework { class GradOpDescMakerBase { public: - explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {} + explicit GradOpDescMakerBase( + const OpDescBind& fwd_op, + const std::unordered_set& no_grad_set) + : fwd_op_(fwd_op), no_grad_set_(no_grad_set) {} virtual ~GradOpDescMakerBase() = default; virtual std::vector> operator()() const = 0; protected: - static std::vector ToGradNames( - const std::vector& var_names) { + std::vector InputGrad(const std::string& name, + bool drop_empty_grad = true) const { std::vector ret_val; + auto var_names = this->Input(name); ret_val.reserve(var_names.size()); - std::transform(var_names.begin(), var_names.end(), - std::back_inserter(ret_val), GradVarName); - return ret_val; - } - - std::vector InputGrad(const std::string& name) const { - return ToGradNames(fwd_op_.Input(name)); + std::transform( + var_names.begin(), var_names.end(), std::back_inserter(ret_val), + [this](const std::string& fwd_var_name) -> std::string { + auto g_name = GradVarName(fwd_var_name); + return no_grad_set_.count(g_name) == 0 ? g_name : kEmptyVarName; + }); + if (!drop_empty_grad) { + return ret_val; + } + std::vector dropped_ret_val; + dropped_ret_val.reserve(ret_val.size()); + std::copy_if(ret_val.begin(), ret_val.end(), + std::back_inserter(dropped_ret_val), + [](const std::string& str) { return str != kEmptyVarName; }); + return dropped_ret_val; } std::vector OutputGrad(const std::string& name) const { - return ToGradNames(fwd_op_.Output(name)); + std::vector ret_val; + auto onames = this->Output(name); + ret_val.reserve(onames.size()); + std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val), + GradVarName); + return ret_val; } std::vector InputNames() const { @@ -75,6 +94,7 @@ class GradOpDescMakerBase { private: const OpDescBind& fwd_op_; + const std::unordered_set& no_grad_set_; }; class SingleGradOpDescMaker : public GradOpDescMakerBase { @@ -91,6 +111,7 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase { virtual std::unique_ptr Apply() const = 0; }; +template class DefaultGradOpDescMaker : public SingleGradOpDescMaker { public: using SingleGradOpDescMaker::SingleGradOpDescMaker; @@ -102,7 +123,8 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker { for (auto& input_param : this->InputNames()) { grad->SetInput(input_param, this->Input(input_param)); - grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param)); + grad->SetOutput(GradVarName(input_param), + this->InputGrad(input_param, DropEmptyIG)); } for (auto& output_param : this->OutputNames()) { diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 94f75b0f309417e52bb5ad3117797d8c25233257..504afbd5dbacf7185f92e0000d19666230e2fb42 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -59,11 +59,5 @@ std::unique_ptr OpRegistry::CreateOp(const OpDescBind& op_desc) { op_desc.GetAttrMap()); } -std::vector> OpRegistry::CreateGradOpDescs( - OpDescBind* op_desc) { - auto& info = OpInfoMap::Instance().Get(op_desc->Type()); - return info.grad_op_maker_(*op_desc); -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 5ca3af52a6909eeee21f647d0e60c7a690f90190..226e8ddcd4b1a2630e0eea00ee6c9f6af6bd5d20 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -79,9 +79,6 @@ class OpRegistry { static std::unique_ptr CreateOp(const OpDesc& op_desc); - static std::vector> CreateGradOpDescs( - OpDescBind* op_desc); - static std::unique_ptr CreateOp(const OpDescBind& op_desc); }; @@ -160,17 +157,18 @@ class OpKernelRegistrar : public Registrar { /** * Macro to register Operator. */ -#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ - grad_op_class) \ - REGISTER_OPERATOR(grad_op_type, grad_op_class); \ - class _GradOpDescMaker_##grad_op_type##_ \ - : public ::paddle::framework::DefaultGradOpDescMaker { \ - using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker; \ - \ - protected: \ - virtual std::string GradOpType() const { return #grad_op_type; } \ - }; \ - REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \ +#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ + grad_op_class) \ + REGISTER_OPERATOR(grad_op_type, grad_op_class); \ + class _GradOpDescMaker_##grad_op_type##_ \ + : public ::paddle::framework::DefaultGradOpDescMaker { \ + using ::paddle::framework::DefaultGradOpDescMaker< \ + true>::DefaultGradOpDescMaker; \ + \ + protected: \ + virtual std::string GradOpType() const { return #grad_op_type; } \ + }; \ + REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \ op_maker_class); #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h index 6f65a942ba2a4073e6aa1047875ec5c3283c23a6..7e1b79c97b5b4c3f0292fbfa205a8ca541702fbc 100644 --- a/paddle/framework/type_defs.h +++ b/paddle/framework/type_defs.h @@ -36,8 +36,8 @@ using OpCreator = std::function; -using GradOpMakerFN = - std::function>(const OpDescBind&)>; +using GradOpMakerFN = std::function>( + const OpDescBind&, const std::unordered_set& /*no_grad_set*/)>; } // namespace framework } // namespace paddle diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index a86685b6dde4761cf74f9521bd9609b0864b9bdf..051051b051961c6da064bd9319460b3f41cea3e8 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -115,8 +115,9 @@ class MultiplexGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad, - ops::MultiplexGradOp); +REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp); REGISTER_OP_CPU_KERNEL( multiplex, ops::MultiplexCPUKernel); REGISTER_OP_CPU_KERNEL(