diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 8f1d9284c503813ef3dd9688891048a5bca57b29..e0db2f26d3e0534f924cc709b98689fb3f1a5cc6 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -45,6 +45,8 @@ class GatherOpCUDAKernel : public framework::OpKernel { axis = static_cast(cpu_axis.data()[0]); } else if (axis_type == framework::proto::VarType::INT64) { axis = static_cast(cpu_axis.data()[0]); + } else if (axis_type == framework::proto::VarType::INT16) { + axis = static_cast(cpu_axis.data()[0]); } } const auto &place = ctx.GetPlace(); @@ -57,6 +59,9 @@ class GatherOpCUDAKernel : public framework::OpKernel { } else if (index_type == framework::proto::VarType::INT64) { phi::funcs::GatherV2CUDAFunction(x, index, axis, output, dev_ctx); + } else if (index_type == framework::proto::VarType::INT16) { + phi::funcs::GatherV2CUDAFunction(x, index, axis, output, + dev_ctx); } return; } @@ -67,6 +72,8 @@ class GatherOpCUDAKernel : public framework::OpKernel { phi::funcs::GPUGather(dev_ctx, *x, *index, output); } else if (index_type == framework::proto::VarType::INT64) { phi::funcs::GPUGather(dev_ctx, *x, *index, output); + } else if (index_type == framework::proto::VarType::INT16) { + phi::funcs::GPUGather(dev_ctx, *x, *index, output); } } }; @@ -134,6 +141,7 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, + ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel); REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel, diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index fbd6197c1b92ee8481a1ce6f4a2cec8482eaefb0..32ccecbc6d9f0282b86f100e1b910667fab41cb2 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1402,7 +1402,8 @@ def gather(x, index, axis=None, name=None): return _C_ops.gather(x, index, None, "axis", axis, "overwrite", False) check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], + x, 'x', + ['float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], 'gather') check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather')