diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 34fd11e8c0d0d0a83698416145d9ea2e37a181ca..648afe7e8215fea4c6875c61f224110c6327516a 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -69,7 +69,11 @@ class GatherOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string& var_name, const framework::Tensor& tensor, const framework::OpKernelType& expected_kernel_type) const override { - return expected_kernel_type; + if (var_name == "Axis") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); } };