diff --git a/paddle/fluid/operators/size_op.cc b/paddle/fluid/operators/size_op.cc index 84b0f403be03893810ef592db9b2c993cc6b9644..4af355bfca64157301458a6b75b5734802f0082a 100644 --- a/paddle/fluid/operators/size_op.cc +++ b/paddle/fluid/operators/size_op.cc @@ -23,6 +23,19 @@ namespace operators { class SizeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto dtype = framework::proto::VarType::FP32; // dtype is not important + return framework::OpKernelType(dtype, ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + return expected_kernel_type; + } }; class SizeOpMaker : public framework::OpProtoAndCheckerMaker { @@ -40,6 +53,8 @@ Return the number of elements in the input. } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERER(SizeOpNoNeedBufferVarInferer, "Input"); + } // namespace operators } // namespace paddle @@ -50,4 +65,4 @@ REGISTER_OPERATOR( size, ops::SizeOp, ops::SizeOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, - SizeInferShapeFunctor); + SizeInferShapeFunctor, ops::SizeOpNoNeedBufferVarInferer);