From 307128d13c592d36f31c65e61fb31943c7e16e8b Mon Sep 17 00:00:00 2001
From: jiangfan06 <117341294+MuShangCC@users.noreply.github.com>
Date: Thu, 10 Aug 2023 19:27:37 +0800
Subject: [PATCH] [XPU] Add gather_nd fp16 and add check_dtype_op_blacklist
 (#55860)

---
 .../framework/ir/auto_mixed_precision_pass.cc |  5 ++-
 paddle/phi/backends/xpu/xpu2_op_list.cc       |  3 +-
 paddle/phi/kernels/xpu/gather_nd_kernel.cc    | 41 +++++++++++--------
 3 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
index 21d2a602001..1161a53e16e 100644
--- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
+++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
@@ -406,7 +406,10 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
       support_low_precision = OpSupportPrecision(
           GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
 
-      if (op_node->Op()->HasAttr("dtype")) {
+      std::unordered_set<std::string> check_dtype_op_blacklist(
+          {"arg_max", "arg_min"});
+      if (op_node->Op()->HasAttr("dtype") &&
+          !check_dtype_op_blacklist.count(GetOpOriginalType(op_type))) {
         auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
         support_low_precision =
             support_low_precision &&
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index c3d8fc0819a..bb22e15d43c 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -397,7 +397,8 @@ XPUOpMap& get_kl2_ops() {
     {"gather_nd",
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::INT64,
-                   phi::DataType::FLOAT32})},
+                   phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16})},
     {"gather",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
diff --git a/paddle/phi/kernels/xpu/gather_nd_kernel.cc b/paddle/phi/kernels/xpu/gather_nd_kernel.cc
index c790d4d9de7..9966d3795d5 100644
--- a/paddle/phi/kernels/xpu/gather_nd_kernel.cc
+++ b/paddle/phi/kernels/xpu/gather_nd_kernel.cc
@@ -24,6 +24,7 @@ void GatherNdKernel(const Context &ctx,
                     const DenseTensor &x,
                     const DenseTensor &index,
                     DenseTensor *out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   ctx.template Alloc<T>(out);
 
   if (x.numel() == 0) {
@@ -57,8 +58,8 @@ void GatherNdKernel(const Context &ctx,
     // int broadcast(Context* ctx, const T* x, T* y, const std::vector<int>&
     // xshape, const std::vector<int>& yshape)
     int r = xpu::broadcast<XPUType>(ctx.x_context(),
-                                    x.data<T>(),
-                                    out->data<T>(),
+                                    reinterpret_cast<const XPUType*>(x.data<T>()),
+                                    reinterpret_cast<XPUType*>(out->data<T>()),
                                     {1, x_numel},
                                     {remain_numel, x_numel});
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
@@ -87,24 +88,32 @@ void GatherNdKernel(const Context &ctx,
 
   int ret = XPU_SUCCESS;
   if (index_type == DataType::INT32) {
-    ret = xpu::gather_nd<XPUType, int>(ctx.x_context(),
-                                       x.data<T>(),
-                                       index.data<int>(),
-                                       out->data<T>(),
-                                       x_vec,
-                                       index_shape);
+    ret = xpu::gather_nd<XPUType, int>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType*>(x.data<T>()),
+        index.data<int>(),
+        reinterpret_cast<XPUType*>(out->data<T>()),
+        x_vec,
+        index_shape);
   } else {
-    ret = xpu::gather_nd<XPUType, int64_t>(ctx.x_context(),
-                                           x.data<T>(),
-                                           index.data<int64_t>(),
-                                           out->data<T>(),
-                                           x_vec,
-                                           index_shape);
+    ret = xpu::gather_nd<XPUType, int64_t>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType*>(x.data<T>()),
+        index.data<int64_t>(),
+        reinterpret_cast<XPUType*>(out->data<T>()),
+        x_vec,
+        index_shape);
   }
   PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather_nd");
 }
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    gather_nd, XPU, ALL_LAYOUT, phi::GatherNdKernel, float, int64_t, int) {}
+PD_REGISTER_KERNEL(gather_nd,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::GatherNdKernel,
+                   float,
+                   int64_t,
+                   int,
+                   phi::dtype::float16) {}
-- 
GitLab