未验证 提交 3e56e816 编写于 作者: L Li Min 提交者: GitHub

Add support of int16 for gather op. (#40052)

* add support of int16 for gather op.

* Recover formats.

* Recover formats.

* fix.

* Fix format.

* Fix format.
上级 9f74b84e
...@@ -45,6 +45,8 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> { ...@@ -45,6 +45,8 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]); axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) { } else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]); axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT16) {
axis = static_cast<int>(cpu_axis.data<int16_t>()[0]);
} }
} }
const auto &place = ctx.GetPlace(); const auto &place = ctx.GetPlace();
...@@ -57,6 +59,9 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> { ...@@ -57,6 +59,9 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GatherV2CUDAFunction<T, int64_t>(x, index, axis, output, phi::funcs::GatherV2CUDAFunction<T, int64_t>(x, index, axis, output,
dev_ctx); dev_ctx);
} else if (index_type == framework::proto::VarType::INT16) {
phi::funcs::GatherV2CUDAFunction<T, int16_t>(x, index, axis, output,
dev_ctx);
} }
return; return;
} }
...@@ -67,6 +72,8 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> { ...@@ -67,6 +72,8 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
phi::funcs::GPUGather<T, int>(dev_ctx, *x, *index, output); phi::funcs::GPUGather<T, int>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GPUGather<T, int64_t>(dev_ctx, *x, *index, output); phi::funcs::GPUGather<T, int64_t>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT16) {
phi::funcs::GPUGather<T, int16_t>(dev_ctx, *x, *index, output);
} }
} }
}; };
...@@ -134,6 +141,7 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>, ...@@ -134,6 +141,7 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
ops::GatherOpCUDAKernel<double>, ops::GatherOpCUDAKernel<double>,
ops::GatherOpCUDAKernel<int64_t>, ops::GatherOpCUDAKernel<int64_t>,
ops::GatherOpCUDAKernel<int>, ops::GatherOpCUDAKernel<int>,
ops::GatherOpCUDAKernel<int16_t>,
ops::GatherOpCUDAKernel<plat::float16>, ops::GatherOpCUDAKernel<plat::float16>,
ops::GatherOpCUDAKernel<plat::bfloat16>); ops::GatherOpCUDAKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>, REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
......
...@@ -1402,7 +1402,8 @@ def gather(x, index, axis=None, name=None): ...@@ -1402,7 +1402,8 @@ def gather(x, index, axis=None, name=None):
return _C_ops.gather(x, index, None, "axis", axis, "overwrite", False) return _C_ops.gather(x, index, None, "axis", axis, "overwrite", False)
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], x, 'x',
['float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
'gather') 'gather')
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册