From c52fe48f6ffd62cbdf707a93b54c3f3df5547a06 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Wed, 11 Nov 2020 15:49:39 +0800 Subject: [PATCH] fix the GetKernelTypeForVar of input for fluid.gather (#28534) --- paddle/fluid/operators/gather_op.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 34fd11e8c0..648afe7e82 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -69,7 +69,11 @@ class GatherOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string& var_name, const framework::Tensor& tensor, const framework::OpKernelType& expected_kernel_type) const override { - return expected_kernel_type; + if (var_name == "Axis") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); } }; -- GitLab