From 3e56e8167f634e67005b864ad56970bcc6cc3048 Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Thu, 3 Mar 2022 13:03:44 +0800 Subject: [PATCH] Add support of int16 for gather op. (#40052) * add support of int16 for gather op. * Recover formats. * Recover formats. * fix. * Fix format. * Fix format. --- paddle/fluid/operators/gather_op.cu | 8 ++++++++ python/paddle/tensor/manipulation.py | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu index 8f1d9284c5..e0db2f26d3 100644 --- a/paddle/fluid/operators/gather_op.cu +++ b/paddle/fluid/operators/gather_op.cu @@ -45,6 +45,8 @@ class GatherOpCUDAKernel : public framework::OpKernel { axis = static_cast(cpu_axis.data()[0]); } else if (axis_type == framework::proto::VarType::INT64) { axis = static_cast(cpu_axis.data()[0]); + } else if (axis_type == framework::proto::VarType::INT16) { + axis = static_cast(cpu_axis.data()[0]); } } const auto &place = ctx.GetPlace(); @@ -57,6 +59,9 @@ class GatherOpCUDAKernel : public framework::OpKernel { } else if (index_type == framework::proto::VarType::INT64) { phi::funcs::GatherV2CUDAFunction(x, index, axis, output, dev_ctx); + } else if (index_type == framework::proto::VarType::INT16) { + phi::funcs::GatherV2CUDAFunction(x, index, axis, output, + dev_ctx); } return; } @@ -67,6 +72,8 @@ class GatherOpCUDAKernel : public framework::OpKernel { phi::funcs::GPUGather(dev_ctx, *x, *index, output); } else if (index_type == framework::proto::VarType::INT64) { phi::funcs::GPUGather(dev_ctx, *x, *index, output); + } else if (index_type == framework::proto::VarType::INT16) { + phi::funcs::GPUGather(dev_ctx, *x, *index, output); } } }; @@ -134,6 +141,7 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, + ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel, ops::GatherOpCUDAKernel); REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel, diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index fbd6197c1b..32ccecbc6d 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1402,7 +1402,8 @@ def gather(x, index, axis=None, name=None): return _C_ops.gather(x, index, None, "axis", axis, "overwrite", False) check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], + x, 'x', + ['float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], 'gather') check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') -- GitLab