未验证 提交 307128d1 编写于 作者: J jiangfan06 提交者: GitHub

[XPU] Add gather_nd fp16 and add check_dtype_op_blacklist (#55860)

上级 b546b923
@@ -406,7 +406,10 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
   support_low_precision = OpSupportPrecision(
       GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
-  if (op_node->Op()->HasAttr("dtype")) {
+  std::unordered_set<std::string> check_dtype_op_blacklist(
+      {"arg_max", "arg_min"});
+  if (op_node->Op()->HasAttr("dtype") &&
+      !check_dtype_op_blacklist.count(GetOpOriginalType(op_type))) {
     auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
     support_low_precision =
         support_low_precision &&
......
@@ -397,7 +397,8 @@ XPUOpMap& get_kl2_ops() {
   {"gather_nd",
    XPUKernelSet({phi::DataType::INT32,
                  phi::DataType::INT64,
-                 phi::DataType::FLOAT32})},
+                 phi::DataType::FLOAT32,
+                 phi::DataType::FLOAT16})},
   {"gather",
    XPUKernelSet({phi::DataType::FLOAT32,
                  phi::DataType::FLOAT16,
......
@@ -24,6 +24,7 @@ void GatherNdKernel(const Context &ctx,
                     const DenseTensor &x,
                     const DenseTensor &index,
                     DenseTensor *out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   ctx.template Alloc<T>(out);
   if (x.numel() == 0) {
@@ -57,8 +58,8 @@ void GatherNdKernel(const Context &ctx,
   // int broadcast(Context* ctx, const T* x, T* y, const std::vector<int>&
   // xshape, const std::vector<int>& yshape)
   int r = xpu::broadcast(ctx.x_context(),
-                         x.data<T>(),
-                         out->data<T>(),
+                         reinterpret_cast<const XPUType *>(x.data<T>()),
+                         reinterpret_cast<XPUType *>(out->data<T>()),
                          {1, x_numel},
                          {remain_numel, x_numel});
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
@@ -87,17 +88,19 @@ void GatherNdKernel(const Context &ctx,
   int ret = XPU_SUCCESS;
   if (index_type == DataType::INT32) {
-    ret = xpu::gather_nd<T, int>(ctx.x_context(),
-                                 x.data<T>(),
-                                 index.data<int>(),
-                                 out->data<T>(),
-                                 x_vec,
-                                 index_shape);
+    ret = xpu::gather_nd<XPUType, int>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType *>(x.data<T>()),
+        index.data<int>(),
+        reinterpret_cast<XPUType *>(out->data<T>()),
+        x_vec,
+        index_shape);
   } else {
-    ret = xpu::gather_nd<T, int64_t>(ctx.x_context(),
-                                     x.data<T>(),
-                                     index.data<int64_t>(),
-                                     out->data<T>(),
-                                     x_vec,
-                                     index_shape);
+    ret = xpu::gather_nd<XPUType, int64_t>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType *>(x.data<T>()),
+        index.data<int64_t>(),
+        reinterpret_cast<XPUType *>(out->data<T>()),
+        x_vec,
+        index_shape);
   }
@@ -106,5 +109,11 @@ void GatherNdKernel(const Context &ctx,
 } // namespace phi
-PD_REGISTER_KERNEL(
-    gather_nd, XPU, ALL_LAYOUT, phi::GatherNdKernel, float, int64_t, int) {}
+PD_REGISTER_KERNEL(gather_nd,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::GatherNdKernel,
+                   float,
+                   int64_t,
+                   int,
+                   phi::dtype::float16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册