diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 027073e5d7d6c767ebb02662c6fd8b2cf9306904..c0f993ac7aa26d962b16dc6bab985d01ff7bc1a8 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -146,11 +146,14 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { } std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( - ctx.Input("X")->type(), ctx.GetPlace(), + ctx.Input("Out")->type(), ctx.GetPlace(), framework::StringToDataLayout(data_format), library_); } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE( + SequenceSoftmaxGradOpNoNeedBufferVarsInferer, "X"); + } // namespace operators } // namespace paddle @@ -158,7 +161,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(sequence_softmax, ops::SequenceSoftmaxOp, ops::SequenceSoftmaxOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(sequence_softmax_grad, ops::SequenceSoftmaxGradOp); +REGISTER_OPERATOR(sequence_softmax_grad, ops::SequenceSoftmaxGradOp, + ops::SequenceSoftmaxGradOpNoNeedBufferVarsInferer); REGISTER_OP_CPU_KERNEL( sequence_softmax, ops::SequenceSoftmaxKernel,