From 3eebd5b3914a1278b51df3289e01749aeec3e84b Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 8 Oct 2019 10:26:40 +0800 Subject: [PATCH] refine sequence_softmax grad maker, test=develop (#20127) --- .../fluid/operators/sequence_ops/sequence_softmax_op.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 027073e5d7d..c0f993ac7aa 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -146,11 +146,14 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { } std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( - ctx.Input("X")->type(), ctx.GetPlace(), + ctx.Input("Out")->type(), ctx.GetPlace(), framework::StringToDataLayout(data_format), library_); } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE( + SequenceSoftmaxGradOpNoNeedBufferVarsInferer, "X"); + } // namespace operators } // namespace paddle @@ -158,7 +161,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(sequence_softmax, ops::SequenceSoftmaxOp, ops::SequenceSoftmaxOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(sequence_softmax_grad, ops::SequenceSoftmaxGradOp); +REGISTER_OPERATOR(sequence_softmax_grad, ops::SequenceSoftmaxGradOp, + ops::SequenceSoftmaxGradOpNoNeedBufferVarsInferer); REGISTER_OP_CPU_KERNEL( sequence_softmax, ops::SequenceSoftmaxKernel, -- GitLab