From d658244997ed359e05bb3070f0a3efb5c4d708aa Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Tue, 2 Apr 2019 01:48:44 -0500
Subject: [PATCH] fix some grad op desc maker (#16581)

test=develop
---
 .../fluid/op_use_default_grad_op_maker.spec   | 12 -----
 paddle/fluid/operators/affine_grid_op.cc      |  7 ++-
 .../operators/bilinear_tensor_product_op.cc   | 43 ++++++++++----
 paddle/fluid/operators/group_norm_op.cc       |  4 +-
 paddle/fluid/operators/norm_op.cc             | 24 ++++++++-
 paddle/fluid/operators/pad2d_op.cc            | 19 +++++--
 .../sequence_ops/sequence_concat_op.cc        | 41 +++++++++++--
 .../sequence_ops/sequence_concat_op.h         | 36 +++++++++----
 .../sequence_ops/sequence_conv_op.cc          | 51 ++++++++++++++++++-
 .../sequence_ops/sequence_expand_as_op.cc     | 45 ++++++++++++++--
 .../sequence_ops/sequence_expand_op.cc        | 43 ++++++++++++++--
 .../operators/sequence_ops/sequence_pad_op.cc | 29 +++++++++--
 .../sequence_ops/sequence_pool_op.cc          | 12 +++--
 .../sequence_ops/sequence_scatter_op.cc       | 35 +++++++++++--
 .../sequence_ops/sequence_slice_op.cc         | 33 ++++++++++--
 .../sequence_ops/sequence_unpad_op.cc         | 30 +++++++++--
 .../sequence_ops/sequence_unpad_op.h          |  3 +-
 paddle/fluid/operators/shuffle_channel_op.cc  | 14 ++---
 paddle/fluid/operators/shuffle_channel_op.cu  | 13 ++---
 paddle/fluid/operators/shuffle_channel_op.h   | 12 ++---
 .../sigmoid_cross_entropy_with_logits_op.cc   | 23 ++++++++-
 paddle/fluid/operators/slice_op.cc            | 14 ++++-
 paddle/fluid/operators/temporal_shift_op.cc   | 33 ++++++++----
 ..._deletion_no_need_buffer_vars_inference.py | 45 ++++++++++++++++
 24 files changed, 511 insertions(+), 110 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_eager_deletion_no_need_buffer_vars_inference.py

diff --git a/paddle/fluid/op_use_default_grad_op_maker.spec b/paddle/fluid/op_use_default_grad_op_maker.spec
index 568db2cf0..712449f6b 100644
--- a/paddle/fluid/op_use_default_grad_op_maker.spec
+++ b/paddle/fluid/op_use_default_grad_op_maker.spec
@@ -3,7 +3,6 @@ acos
 asin
 atan
 attention_lstm
-bilinear_tensor_product
 brelu
 conv_shift
 cos
@@ -43,7 +42,6 @@ max_pool3d_with_index
 maxout
 modified_huber_loss
 nce
-norm
 pool2d
 pool3d
 pow
@@ -60,16 +58,7 @@ reshape
 rnn_memory_helper
 round
 row_conv
-sequence_concat
-sequence_conv
-sequence_expand
-sequence_expand_as
-sequence_pad
-sequence_scatter
-sequence_slice
 sequence_softmax
-sequence_unpad
-sigmoid_cross_entropy_with_logits
 sin
 softplus
 softshrink
@@ -84,7 +73,6 @@ stanh
 swish
 tanh_shrink
 teacher_student_sigmoid_loss
-temporal_shift
 tensor_array_to_tensor
 thresholded_relu
 transpose
diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc
index 1de59a516..9d7100cc3 100644
--- a/paddle/fluid/operators/affine_grid_op.cc
+++ b/paddle/fluid/operators/affine_grid_op.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/affine_grid_op.h"
+#include <memory>
 #include <string>
+#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
@@ -173,9 +175,10 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    auto theta_dims = ctx->GetInputDim("Theta");
     if (ctx->HasOutput(framework::GradVarName("Theta"))) {
-      ctx->SetOutputDim(framework::GradVarName("Theta"), theta_dims);
+      auto output_dims = ctx->GetInputDim(framework::GradVarName("Output"));
+      ctx->SetOutputDim(framework::GradVarName("Theta"),
+                        {output_dims[0], 2, 3});
     }
   }
diff --git a/paddle/fluid/operators/bilinear_tensor_product_op.cc b/paddle/fluid/operators/bilinear_tensor_product_op.cc
index 8d261a118..bd69f422e 100644
--- a/paddle/fluid/operators/bilinear_tensor_product_op.cc
+++ b/paddle/fluid/operators/bilinear_tensor_product_op.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/bilinear_tensor_product_op.h"
+#include <memory>
+#include <string>
+#include <vector>
 
 namespace paddle {
 namespace operators {
@@ -121,15 +124,9 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
         "The second dimension of input(Out@GRAD) must be equal to "
         "the third dimension of the Input(Weight).");
 
-    if (ctx->HasInput("Bias")) {
-      auto bias_dims = ctx->GetInputDim("Bias");
-      PADDLE_ENFORCE_EQ(
-          bias_dims[1], out_dims[1],
-          "The second dimension of input(Out@GRAD) must be equal to "
-          "the second dimension of the Input(Bias).");
-      auto bias_grad_name = framework::GradVarName("Bias");
-      if (ctx->HasOutput(bias_grad_name))
-        ctx->SetOutputDim(bias_grad_name, bias_dims);
+    auto bias_grad_name = framework::GradVarName("Bias");
+    if (ctx->HasOutput(bias_grad_name)) {
+      ctx->SetOutputDim(bias_grad_name, {1, out_dims[1]});
     }
 
     auto x_grad_name = framework::GradVarName("X");
@@ -148,13 +145,39 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
   }
 };
 
+class BilinearTensorProductGradOpDescMaker
+    : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("bilinear_tensor_product_grad");
+    op->SetAttrMap(Attrs());
+    op->SetInput("X", Input("X"));
+    op->SetInput("Y", Input("Y"));
+    op->SetInput("Weight", Input("Weight"));
+    if (ForwardOp().Inputs().count("Bias") > 0) {
+      op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
+    }
+
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
+    op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(bilinear_tensor_product, ops::BilinearTensorProductOp,
                   ops::BilinearTensorProductOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::BilinearTensorProductGradOpDescMaker);
 REGISTER_OPERATOR(bilinear_tensor_product_grad,
                   ops::BilinearTensorProductOpGrad);
 REGISTER_OP_CPU_KERNEL(
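A note on the pattern repeated throughout this patch: DefaultGradOpDescMaker<true> wires every forward input and output (plus their gradients) into the generated grad op, which forces all of those buffers to stay alive through the backward pass. Each hand-written SingleGradOpDescMaker in this patch instead declares only what the grad kernel actually reads. A minimal sketch of the shape of such a maker follows ("my_op" and its variable names are illustrative, not from the patch; it assumes the same headers and namespace as the operator files above):

    class MyOpGradDescMaker : public framework::SingleGradOpDescMaker {
     public:
      using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

     protected:
      std::unique_ptr<framework::OpDesc> Apply() const override {
        std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
        op->SetType("my_op_grad");
        // Wire in only Out@GRAD; the forward input X is deliberately not an
        // input of the grad op, so its buffer can be reclaimed early.
        op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
        op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
        op->SetAttrMap(Attrs());
        return op;
      }
    };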
diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc
index 2ab40f482..09fd6a25d 100644
--- a/paddle/fluid/operators/group_norm_op.cc
+++ b/paddle/fluid/operators/group_norm_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <vector>
 
 namespace paddle {
 namespace operators {
@@ -107,8 +108,6 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
     // check input
     PADDLE_ENFORCE(ctx->HasInput("Y"),
                    "Input(Y) of GroupNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Mean"),
-                   "Input(Mean) of GroupNormOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Variance"),
                    "Input(Variance) of GroupNormOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
@@ -159,7 +158,6 @@ class GroupNormGradMaker : public framework::SingleGradOpDescMaker {
     op->SetInput("Bias", Input("Bias"));
     op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
     op->SetInput("Y", Output("Y"));
-    op->SetInput("Mean", Output("Mean"));
     op->SetInput("Variance", Output("Variance"));
 
     op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
diff --git a/paddle/fluid/operators/norm_op.cc b/paddle/fluid/operators/norm_op.cc
index aa19c62c8..81fbe3e51 100644
--- a/paddle/fluid/operators/norm_op.cc
+++ b/paddle/fluid/operators/norm_op.cc
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/norm_op.h"
+#include <memory>
+#include <string>
+#include <vector>
+
 namespace paddle {
 namespace operators {
@@ -74,6 +78,24 @@ class NormOpGrad : public framework::OperatorWithKernel {
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
 };
+
+class NormOpGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("norm_grad");
+    op->SetAttrMap(Attrs());
+    op->SetInput("X", Input("X"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetInput("Norm", Output("Norm"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -81,7 +103,7 @@ namespace ops = paddle::operators;
 using CPU = paddle::platform::CPUDeviceContext;
 
 REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::NormOpGradOpDescMaker);
 REGISTER_OPERATOR(norm_grad, ops::NormOpGrad);
 REGISTER_OP_CPU_KERNEL(norm, ops::NormKernel<CPU, float>,
                        ops::NormKernel<CPU, double>);
diff --git a/paddle/fluid/operators/pad2d_op.cc b/paddle/fluid/operators/pad2d_op.cc
index 6ef2dacb3..9731aefa9 100644
--- a/paddle/fluid/operators/pad2d_op.cc
+++ b/paddle/fluid/operators/pad2d_op.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -612,8 +615,9 @@ class Pad2dOpGrad : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.GetPlace());
   }
 };
 
@@ -625,7 +629,9 @@ class Pad2dOpGradMaker : public framework::SingleGradOpDescMaker {
   std::unique_ptr<framework::OpDesc> Apply() const override {
     auto* bind = new framework::OpDesc();
     bind->SetInput("X", Input("X"));
-    bind->SetInput("Paddings", Input("Paddings"));
+    if (ForwardOp().Inputs().count("Paddings") > 0) {
+      bind->SetInput("Paddings", Input("Paddings"));
+    }
     bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
     bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
     bind->SetAttrMap(Attrs());
@@ -634,6 +640,10 @@ class Pad2dOpGradMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
+// TODO(zjl): Paddings can also be skipped!
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(Pad2dOpGradNoNeedBufferVarsInference,
+                                      "X");
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -641,6 +651,7 @@ namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker,
                   ops::Pad2dOpGradMaker);
-REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad);
+REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad,
+                  ops::Pad2dOpGradNoNeedBufferVarsInference);
 REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>);
 REGISTER_OP_CPU_KERNEL(pad2d_grad, ops::Pad2dGradCPUKernel<float>);
diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc
index 37f1b9dda..d652f9216 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
#include "paddle/fluid/operators/sequence_ops/sequence_concat_op.h" +#include #include namespace paddle { @@ -73,13 +74,43 @@ class SeqConcatShapeInferer : public framework::InferShapeBase { } }; -class SeqConcatGradShapeInferer : public framework::InferShapeBase { +class SeqConcatGradOpDescMaker : public framework::SingleGradOpDescMaker { public: - void operator()(framework::InferShapeContext *context) const override { + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("sequence_concat_grad"); + op->SetInput("X", Input("X")); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X", false)); + op->SetAttrMap(Attrs()); + return op; + } +}; + +class SeqConcatGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *context) const override { context->SetOutputsDim(framework::GradVarName("X"), context->GetInputsDim("X")); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Out"))->type(), + ctx.GetPlace()); + } }; + +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SeqConcatGradNoNeedBufferVarsInference, + "X"); + } // namespace operators } // namespace paddle @@ -87,14 +118,14 @@ namespace op = paddle::operators; REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel, op::SeqConcatOpMaker, op::SeqConcatShapeInferer, - paddle::framework::DefaultGradOpDescMaker); + op::SeqConcatGradOpDescMaker); template using Kernel = op::SeqConcatKernel; REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel, Kernel, Kernel); -REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel, - op::SeqConcatGradShapeInferer); +REGISTER_OPERATOR(sequence_concat_grad, op::SeqConcatGradOp, + op::SeqConcatGradNoNeedBufferVarsInference); template using GradKernel = op::SeqConcatGradKernel; diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.h b/paddle/fluid/operators/sequence_ops/sequence_concat_op.h index ff035f421..f9b2ed384 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.h @@ -14,7 +14,9 @@ #pragma once +#include #include +#include "boost/optional.hpp" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/math/concat_and_split.h" @@ -89,37 +91,49 @@ class SeqConcatGradKernel : public framework::OpKernel { dxs[i]->mutable_data(context.GetPlace()); } } + std::vector sliced_x; - std::vector> sliced_dx; + std::vector> sliced_dx; for (size_t i = 1; i < xs[0]->lod()[0].size(); ++i) { for (size_t j = 0; j < xs.size(); ++j) { const framework::LoDTensor *x = xs[j]; + framework::DDim x_dims = x->dims(); + framework::LoDTensor *dx = dxs[j]; auto &x_lod = x->lod()[0]; - sliced_x.emplace_back(x->Slice(x_lod[i - 1], x_lod[i])); - if (dx != nullptr) { - sliced_dx.emplace_back(dx->Slice(x_lod[i - 1], x_lod[i])); + + auto prev_lod = x_lod[i - 1]; + auto next_lod = x_lod[i]; + + x_dims[0] = next_lod - prev_lod; + + sliced_x.emplace_back(); + sliced_x.back().Resize(x_dims); + + if (dx) { + sliced_dx.emplace_back(dx->Slice(prev_lod, next_lod)); } else { - sliced_dx.emplace_back(boost::blank()); + 
+          sliced_dx.emplace_back(boost::none);
         }
       }
     }
 
-    math::SplitFunctor<DeviceContext, T> functor;
     std::vector<const framework::Tensor *> sliced_x_ptr;
-    std::vector<framework::Tensor *> sliced_dx_ptr;
+    sliced_x_ptr.reserve(sliced_x.size());
     for (auto &x : sliced_x) {
       sliced_x_ptr.emplace_back(&x);
     }
 
+    std::vector<framework::Tensor *> sliced_dx_ptr;
+    sliced_dx_ptr.reserve(sliced_dx.size());
     for (auto &dx : sliced_dx) {
-      try {
-        sliced_dx_ptr.emplace_back(&boost::get<framework::Tensor>(dx));
-      } catch (boost::bad_get &) {
-        sliced_dx_ptr.emplace_back(nullptr);
+      if (dx) {
+        sliced_dx_ptr.emplace_back(&dx.get());
       }
     }
+
+    math::SplitFunctor<DeviceContext, T> functor;
     functor(context.template device_context<DeviceContext>(),
             detail::Ref(
                 context.Input<framework::LoDTensor>(
                     framework::GradVarName("Out")),
diff --git a/paddle/fluid/operators/sequence_ops/sequence_conv_op.cc b/paddle/fluid/operators/sequence_ops/sequence_conv_op.cc
index 65cd9edbc..89c1fe834 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_conv_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_conv_op.cc
@@ -15,6 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/sequence_ops/sequence_conv_op.h"
 
 #include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_set>
 
 namespace paddle {
 namespace operators {
@@ -171,13 +174,57 @@ context_length, context_stride and context_start.
   }
 };
 
+class SequenceConvGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("sequence_conv_grad");
+    op->SetAttrMap(Attrs());
+
+    if (boost::get<bool>(Attrs().at("paddingTrainable")) &&
+        ForwardOp().Inputs().count("PaddingData") > 0) {
+      op->SetInput("PaddingData", Input("PaddingData"));
+      op->SetOutput(framework::GradVarName("PaddingData"),
+                    InputGrad("PaddingData"));
+    }
+
+    op->SetInput("X", Input("X"));
+    op->SetInput("Filter", Input("Filter"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
+
+    return op;
+  }
+};
+
+class SequenceConvGradNoNeedBufferVarsInference
+    : public framework::NoNeedBufferVarsInference {
+ public:
+  using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
+
+  std::unordered_set<std::string> operator()() const override {
+    if (!boost::get<bool>(Attrs().at("paddingTrainable"))) {
+      return {"PaddingData"};
+    } else {
+      return {};
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(sequence_conv_grad, ops::SequenceConvGradOp);
+                  ops::SequenceConvGradOpDescMaker);
+
+REGISTER_OPERATOR(sequence_conv_grad, ops::SequenceConvGradOp,
+                  ops::SequenceConvGradNoNeedBufferVarsInference);
 
 REGISTER_OP_CPU_KERNEL(
     sequence_conv,
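The sequence_conv change above also shows the second half of the recurring pattern: a no-need-buffer declaration tells the framework that the grad op consumes only the metadata (dims/LoD) of the listed inputs, never their data, so the eager-deletion pass may free those buffers before the backward op runs. A sketch of the two registration pieces this patch pairs together, under the same assumptions as the earlier sketch (the "MyOp"/"my_op" names are illustrative):

    // Static form: the no-need-buffer variable list is fixed at registration.
    DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(MyOpGradNoNeedBufferVarsInference,
                                          "X");

    // The inference class is then attached to the grad op's registration:
    REGISTER_OPERATOR(my_op_grad, ops::MyOpGradOp,
                      ops::MyOpGradNoNeedBufferVarsInference);

SequenceConvGradNoNeedBufferVarsInference above is the attribute-dependent form of the same idea: it subclasses framework::NoNeedBufferVarsInference and computes the set from the op's attributes at runtime (PaddingData is only needed when paddingTrainable is set).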
diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc
index 3b79d0c71..e1f6c3e3d 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h"
+#include <memory>
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -70,6 +72,12 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", out_dims);
     ctx->ShareLoD("Y", /*->*/ "Out");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>("X")->type(), ctx.GetPlace());
+  }
 };
 
 class SequenceExpandAsOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -131,7 +139,6 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null.");
 
@@ -143,16 +150,48 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
       ctx->ShareLoD("X", x_grad_name);
     }
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.GetPlace());
+  }
 };
 
+class SequenceExpandAsOpGradOpDescMaker
+    : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("sequence_expand_as_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("Y", Input("Y"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
+    SequenceExpandAsOpNoNeedBufferVarsInference, "Y");
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
+    SequenceExpandAsGradOpNoNeedBufferVarsInference, "X", "Y");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sequence_expand_as, ops::SequenceExpandAsOp,
                   ops::SequenceExpandAsOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad);
+                  ops::SequenceExpandAsOpGradOpDescMaker,
+                  ops::SequenceExpandAsOpNoNeedBufferVarsInference);
+REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad,
+                  ops::SequenceExpandAsGradOpNoNeedBufferVarsInference);
 REGISTER_OP_CPU_KERNEL(
     sequence_expand_as,
     ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc
index f6c424153..b7c042063 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/sequence_ops/sequence_expand_op.h"
+#include <memory>
 
 namespace paddle {
 namespace operators {
@@ -96,6 +97,12 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", out_dims);
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>("X")->type(), ctx.GetPlace());
+  }
 };
 
 class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -188,7 +195,6 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null.");
 
@@ -199,16 +205,47 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
       ctx->SetOutputDim(x_grad_name, x_dims);
     }
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.GetPlace());
+  }
 };
 
+class SequenceExpandOpGradDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("sequence_expand_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("Y", Input("Y"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SequenceExpandOpNoNeedBufferVarsInference,
+                                      "Y");
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
+    SequenceExpandGradOpNoNeedBufferVarsInference, "X", "Y");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sequence_expand, ops::SequenceExpandOp,
                   ops::SequenceExpandOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(sequence_expand_grad, ops::SequenceExpandOpGrad);
+                  ops::SequenceExpandOpGradDescMaker,
+                  ops::SequenceExpandOpNoNeedBufferVarsInference);
+REGISTER_OPERATOR(sequence_expand_grad, ops::SequenceExpandOpGrad,
+                  ops::SequenceExpandGradOpNoNeedBufferVarsInference);
 REGISTER_OP_CPU_KERNEL(
     sequence_expand,
     ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
index 23c7bf7ce..5290d0e6c 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/sequence_ops/sequence_pad_op.h"
+#include <memory>
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -194,18 +196,39 @@ class SequencePadGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
+    auto data_type = framework::GetDataTypeOfVar(
+        ctx.InputVar(framework::GradVarName("Out")));
     return framework::OpKernelType(data_type, ctx.device_context());
   }
 };
 
+class SequencePadGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("sequence_pad_grad");
+    op->SetAttrMap(Attrs());
+    op->SetInput("X", Input("X"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    return op;
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
+    SequencePadGradOpNoNeedBufferVarsInference, "X");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp);
+                  ops::SequencePadGradOpDescMaker);
+REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp,
+                  ops::SequencePadGradOpNoNeedBufferVarsInference);
 REGISTER_OP_CPU_KERNEL(
     sequence_pad,
     ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
index 1754221e7..b4923571d 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h"
+#include <memory>
 #include <string>
 
 namespace paddle {
 namespace operators {
@@ -114,8 +115,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.device_context());
   }
 };
 
@@ -138,13 +140,17 @@ class SequencePoolGradOpMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
+    SequencePoolGradOpNoNeedBufferVarsInference, "X");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
                   ops::SequencePoolGradOpMaker);
-REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp);
+REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp,
+                  ops::SequencePoolGradOpNoNeedBufferVarsInference);
 REGISTER_OP_CPU_KERNEL(
     sequence_pool,
     ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>);
diff --git a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
index 8267c04f9..5a22212ed 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/sequence_ops/sequence_scatter_op.h"
+#include <memory>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/gather.h"
@@ -124,25 +125,49 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     ctx->SetOutputDim(framework::GradVarName("Updates"),
                       ctx->GetInputDim("Updates"));
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->SetOutputDim(framework::GradVarName("X"),
+                      ctx->GetInputDim(framework::GradVarName("Out")));
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
+        platform::CPUPlace());
   }
 };
 
+class SequenceScatterGradDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("sequence_scatter_grad");
+    op->SetInput("Ids", Input("Ids"));
+    op->SetInput("Updates", Input("Updates"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Updates"), InputGrad("Updates"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
+    SequenceScatterGradNoNeedBufferVarsInference, "Updates");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sequence_scatter, ops::SequenceScatterOp,
                   ops::SequenceScatterOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp);
+                  ops::SequenceScatterGradDescMaker);
+REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp,
+                  ops::SequenceScatterGradNoNeedBufferVarsInference);
 REGISTER_OP_CPU_KERNEL(sequence_scatter, ops::SequenceScatterOpKernel<float>,
                        ops::SequenceScatterOpKernel<double>,
                        ops::SequenceScatterOpKernel<int>,
diff --git a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc
index 35f49f78c..4b2ec6e7c 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/sequence_ops/sequence_slice_op.h"
+#include <memory>
 
 namespace paddle {
 namespace operators {
@@ -70,8 +71,9 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.device_context());
   }
 };
 
@@ -113,14 +115,35 @@ NOTE: The first dimension size of input, the size of offset and Length, should b
   }
 };
 
+class SequenceSliceGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("sequence_slice_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("Offset", Input("Offset"));
+    op->SetInput("Length", Input("Length"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
+    SequenceSliceGradNoNeedBufferVarsInference, "X");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sequence_slice, ops::SequenceSliceOp,
-                  ops::SequenceSliceOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(sequence_slice_grad, ops::SequenceSliceGradOp);
+                  ops::SequenceSliceOpMaker, ops::SequenceSliceGradOpDescMaker);
+REGISTER_OPERATOR(sequence_slice_grad, ops::SequenceSliceGradOp,
+                  ops::SequenceSliceGradNoNeedBufferVarsInference);
 REGISTER_OP_CPU_KERNEL(
     sequence_slice,
     ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, float>);
diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc
index 2cf508e0b..6c98a3e87 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h"
+#include <memory>
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -125,19 +127,39 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
+    auto data_type = framework::GetDataTypeOfVar(
+        ctx.InputVar(framework::GradVarName("Out")));
     return framework::OpKernelType(data_type, ctx.device_context());
   }
 };
 
+class SequenceUnpadGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("sequence_unpad_grad");
+    op->SetAttrMap(Attrs());
+    op->SetInput("X", Input("X"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    return op;
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
+    SequenceUnpadGradOpNoNeedBufferVarsInference, "X");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp,
-                  ops::SequenceUnpadOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp);
+                  ops::SequenceUnpadOpMaker, ops::SequenceUnpadGradOpDescMaker);
+REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp,
+                  ops::SequenceUnpadGradOpNoNeedBufferVarsInference);
 REGISTER_OP_CPU_KERNEL(
     sequence_unpad,
     ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
index 07df3dca8..fe8ca41b6 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
@@ -81,10 +81,9 @@ class SequenceUnpadGradOpKernel : public framework::OpKernel<T> {
     auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
     if (d_x) {
       const auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
-      const auto* x_t = ctx.Input<LoDTensor>("X");
       d_x->mutable_data<T>(ctx.GetPlace());
 
-      int padded_length = x_t->dims()[1];
+      int padded_length = d_x->dims()[1];
 
       LoDTensor zero_pads;
       zero_pads.Resize({1, 1});
diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc
index 26355e586..ad6fb3510 100644
--- a/paddle/fluid/operators/shuffle_channel_op.cc
+++ b/paddle/fluid/operators/shuffle_channel_op.cc
@@ -11,6 +11,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/shuffle_channel_op.h"
 #include <memory>
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -73,12 +74,7 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@Grad) should not be null");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
-                   "Output(X@Grad) should not be null");
-
-    auto input_dims = ctx->GetInputDim("X");
+    auto input_dims = ctx->GetInputDim(framework::GradVarName("Out"));
     PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
 
     ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
@@ -87,8 +83,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.device_context());
   }
 };
 
@@ -100,7 +97,6 @@ class ShuffleChannelGradDescMaker : public framework::SingleGradOpDescMaker {
   std::unique_ptr<framework::OpDesc> Apply() const override {
     std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
     op->SetType("shuffle_channel_grad");
-    op->SetInput("X", Input("X"));
     op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
     op->SetAttrMap(Attrs());
diff --git a/paddle/fluid/operators/shuffle_channel_op.cu b/paddle/fluid/operators/shuffle_channel_op.cu
index 9506343b3..dbc3e1a7e 100644
--- a/paddle/fluid/operators/shuffle_channel_op.cu
+++ b/paddle/fluid/operators/shuffle_channel_op.cu
@@ -78,10 +78,14 @@ template <typename DeviceContext, typename T>
 class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* input = ctx.Input<framework::Tensor>("X");
+    auto* output_grad =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* input_grad =
+        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
     int group = ctx.Attr<int>("group");
 
-    auto input_dims = input->dims();
+    const auto& input_dims = input_grad->dims();
     auto num = input_dims[0];
     auto channel = input_dims[1];
     auto height = input_dims[2];
@@ -91,10 +95,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
     int group_row = group;
     int group_column = channel / group_row;
 
-    auto* output_grad =
-        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* input_grad =
-        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
     T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
     const T* output_grad_data = output_grad->data<T>();
diff --git a/paddle/fluid/operators/shuffle_channel_op.h b/paddle/fluid/operators/shuffle_channel_op.h
index f6af1bc88..3ce1e0c77 100644
--- a/paddle/fluid/operators/shuffle_channel_op.h
+++ b/paddle/fluid/operators/shuffle_channel_op.h
@@ -57,10 +57,14 @@ template <typename DeviceContext, typename T>
 class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* input = ctx.Input<framework::Tensor>("X");
+    auto* output_grad =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* input_grad =
+        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
     int group = ctx.Attr<int>("group");
 
-    auto input_dims = input->dims();
+    const auto& input_dims = input_grad->dims();
     auto num = input_dims[0];
     auto channel = input_dims[1];
     auto height = input_dims[2];
@@ -71,10 +75,6 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
     int group_row = group;
     int group_column = channel / group_row;
 
-    auto* output_grad =
-        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* input_grad =
-        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
     const T* output_grad_data = output_grad->data<T>();
     for (int n = 0; n < num; ++n) {
diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
index c21b0c13c..5c92588cc 100644
--- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
+++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
+#include <memory>
+#include <string>
+#include <vector>
 
 namespace paddle {
 namespace operators {
@@ -139,6 +142,24 @@ However the output only shares the LoD with input `X`.
   }
 };
 
+class SigmoidCrossEntropyWithLogitsGradOpDescMaker
+    : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("sigmoid_cross_entropy_with_logits_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("Label", Input("Label"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -146,7 +167,7 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits,
                   ops::SigmoidCrossEntropyWithLogitsOp,
                   ops::SigmoidCrossEntropyWithLogitsOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker);
 REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad,
                   ops::SigmoidCrossEntropyWithLogitsGradOp);
 REGISTER_OP_CPU_KERNEL(
diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc
index 94995fc99..589c98e51 100644
--- a/paddle/fluid/operators/slice_op.cc
+++ b/paddle/fluid/operators/slice_op.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/slice_op.h"
 #include <algorithm>
+#include <memory>
 #include <vector>
 
 namespace paddle {
@@ -135,6 +136,13 @@ class SliceOpGrad : public framework::OperatorWithKernel {
       ctx->SetOutputDim(x_grad_name, x_dims);
     }
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.GetPlace());
+  }
 };
 
 class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
@@ -153,13 +161,17 @@ class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SliceOpGradNoNeedBufferVarsInference,
+                                      "Input");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
                   ops::SliceOpGradMaker);
-REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad);
+REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad,
+                  ops::SliceOpGradNoNeedBufferVarsInference);
 REGISTER_OP_CPU_KERNEL(
     slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc
index 7df649fc5..3b7d90b79 100644
--- a/paddle/fluid/operators/temporal_shift_op.cc
+++ b/paddle/fluid/operators/temporal_shift_op.cc
@@ -10,6 +10,9 @@
    limitations under the License. */
 
 #include "paddle/fluid/operators/temporal_shift_op.h"
+#include <memory>
+#include <string>
+#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
@@ -125,19 +128,32 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) should not be null");
-    auto dim_x = ctx->GetInputDim("X");
     if (ctx->HasOutput(framework::GradVarName("X"))) {
-      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
+      ctx->SetOutputDim(framework::GradVarName("X"),
+                        ctx->GetInputDim(framework::GradVarName("Out")));
     }
   }
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.GetPlace());
+  }
+};
+
+class TemporalShiftGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("temporal_shift_grad");
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
   }
 };
 
@@ -146,8 +162,7 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
-                  ops::TemporalShiftOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::TemporalShiftOpMaker, ops::TemporalShiftGradOpDescMaker);
 REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
 REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
                        ops::TemporalShiftKernel<double>);
diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_no_need_buffer_vars_inference.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_no_need_buffer_vars_inference.py
new file mode 100644
index 000000000..4844d930d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_no_need_buffer_vars_inference.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle.fluid as fluid
+import importlib
+
+fluid.core._set_eager_deletion_mode(0.0, 1.0, True)
+
+from test_elementwise_add_op import *
+from test_elementwise_sub_op import *
+from test_concat_op import *
+from test_gather_op import *
+from test_gaussian_random_batch_size_like_op import *
+from test_lod_reset_op import *
+from test_scatter_op import *
+from test_mean_op import *
+from test_slice_op import *
+from test_linear_chain_crf_op import *
+from test_bilinear_interp_op import *
+from test_nearest_interp_op import *
+from test_sequence_concat import *
+from test_seq_conv import *
+from test_seq_pool import *
+from test_sequence_expand_as import *
+from test_sequence_expand import *
+from test_sequence_pad_op import *
+from test_sequence_unpad_op import *
+from test_sequence_scatter_op import *
+from test_sequence_slice_op import *
+from test_pad2d_op import *
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
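A closing observation on the recurring GetExpectedKernelType overrides in this patch: once a forward input such as X is no longer wired into the grad op, or is declared no-need-buffer, its tensor may be uninitialized or already freed by the time the kernel is selected, so the grad ops above consistently derive their kernel data type from Out@GRAD, which is guaranteed alive during the backward pass. The shape of the recurring override, as a sketch (this mirrors code already in the patch rather than adding anything new):

    framework::OpKernelType GetExpectedKernelType(
        const framework::ExecutionContext& ctx) const override {
      // Out@GRAD, not the forward input X, determines the kernel data type,
      // because X's buffer may have been reclaimed by eager deletion.
      return framework::OpKernelType(
          ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
          ctx.GetPlace());
    }

The new Python test ties the whole change together: it switches eager deletion on via the _set_eager_deletion_mode call and then re-runs the existing unit tests of every touched operator through wildcard imports, so a grad op that still depends on a buffer it declared unneeded surfaces as a test failure rather than silent memory corruption.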