未验证 提交 3e56e816 编写于 作者: L Li Min 提交者: GitHub

Add support of int16 for gather op. (#40052)

* add support of int16 for gather op.

* Recover formats.

* Recover formats.

* fix.

* Fix format.

* Fix format.
上级 9f74b84e
......@@ -45,6 +45,8 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT16) {
axis = static_cast<int>(cpu_axis.data<int16_t>()[0]);
}
}
const auto &place = ctx.GetPlace();
......@@ -57,6 +59,9 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GatherV2CUDAFunction<T, int64_t>(x, index, axis, output,
dev_ctx);
} else if (index_type == framework::proto::VarType::INT16) {
phi::funcs::GatherV2CUDAFunction<T, int16_t>(x, index, axis, output,
dev_ctx);
}
return;
}
......@@ -67,6 +72,8 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
phi::funcs::GPUGather<T, int>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GPUGather<T, int64_t>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT16) {
phi::funcs::GPUGather<T, int16_t>(dev_ctx, *x, *index, output);
}
}
};
......@@ -134,6 +141,7 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
ops::GatherOpCUDAKernel<double>,
ops::GatherOpCUDAKernel<int64_t>,
ops::GatherOpCUDAKernel<int>,
ops::GatherOpCUDAKernel<int16_t>,
ops::GatherOpCUDAKernel<plat::float16>,
ops::GatherOpCUDAKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
......
......@@ -1402,7 +1402,8 @@ def gather(x, index, axis=None, name=None):
return _C_ops.gather(x, index, None, "axis", axis, "overwrite", False)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
x, 'x',
['float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
'gather')
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册