diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
index 21d2a602001ad834943ac67eb4c81bd039f15b6c..1161a53e16ebca78f5babf4ef12215debaa4c2c6 100644
--- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
+++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
@@ -406,7 +406,10 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
       support_low_precision = OpSupportPrecision(
           GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
 
-      if (op_node->Op()->HasAttr("dtype")) {
+      std::unordered_set<std::string> check_dtype_op_blacklist(
+          {"arg_max", "arg_min"});
+      if (op_node->Op()->HasAttr("dtype") &&
+          !check_dtype_op_blacklist.count(GetOpOriginalType(op_type))) {
         auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
         support_low_precision =
             support_low_precision &&
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index c3d8fc0819aed621683c48b7d4bf89c0dd456c5f..bb22e15d43c6a3335b92dbb97fdcc19eaf7e3c08 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -397,7 +397,8 @@ XPUOpMap& get_kl2_ops() {
     {"gather_nd",
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::INT64,
-                   phi::DataType::FLOAT32})},
+                   phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16})},
     {"gather",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
diff --git a/paddle/phi/kernels/xpu/gather_nd_kernel.cc b/paddle/phi/kernels/xpu/gather_nd_kernel.cc
index c790d4d9de742cce58c66e5ee3a530c9329f4b36..9966d3795d5043723beb7eb62a16a91f1a9c218b 100644
--- a/paddle/phi/kernels/xpu/gather_nd_kernel.cc
+++ b/paddle/phi/kernels/xpu/gather_nd_kernel.cc
@@ -24,6 +24,7 @@ void GatherNdKernel(const Context &ctx,
                     const DenseTensor &x,
                     const DenseTensor &index,
                     DenseTensor *out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   ctx.template Alloc<T>(out);
 
   if (x.numel() == 0) {
@@ -57,8 +58,8 @@ void GatherNdKernel(const Context &ctx,
     // int broadcast(Context* ctx, const T* x, T* y, const std::vector<int>&
     // xshape, const std::vector<int>& yshape)
     int r = xpu::broadcast(ctx.x_context(),
-                           x.data<T>(),
-                           out->data<T>(),
+                           reinterpret_cast<const XPUType*>(x.data<T>()),
+                           reinterpret_cast<XPUType*>(out->data<T>()),
                            {1, x_numel},
                            {remain_numel, x_numel});
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
@@ -87,24 +88,32 @@ void GatherNdKernel(const Context &ctx,
 
   int ret = XPU_SUCCESS;
   if (index_type == DataType::INT32) {
-    ret = xpu::gather_nd<T, int>(ctx.x_context(),
-                                 x.data<T>(),
-                                 index.data<int>(),
-                                 out->data<T>(),
-                                 x_vec,
-                                 index_shape);
+    ret = xpu::gather_nd<XPUType, int>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType*>(x.data<T>()),
+        index.data<int>(),
+        reinterpret_cast<XPUType*>(out->data<T>()),
+        x_vec,
+        index_shape);
   } else {
-    ret = xpu::gather_nd<T, int64_t>(ctx.x_context(),
-                                     x.data<T>(),
-                                     index.data<int64_t>(),
-                                     out->data<T>(),
-                                     x_vec,
-                                     index_shape);
+    ret = xpu::gather_nd<XPUType, int64_t>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType*>(x.data<T>()),
+        index.data<int64_t>(),
+        reinterpret_cast<XPUType*>(out->data<T>()),
+        x_vec,
+        index_shape);
   }
   PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather_nd");
 }
 }  // namespace phi
-PD_REGISTER_KERNEL(
-    gather_nd, XPU, ALL_LAYOUT, phi::GatherNdKernel, float, int64_t, int) {}
+PD_REGISTER_KERNEL(gather_nd,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::GatherNdKernel,
+                   float,
+                   int64_t,
+                   int,
+                   phi::dtype::float16) {}
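The kernel change above follows the usual pattern for adding fp16 support to an XPU kernel: phi stores half-precision data as phi::dtype::float16, while the XDNN device API uses its own 16-bit float type, so the kernel maps T to the device-side type via XPUTypeTrait<T>::Type and reinterprets the data pointers before calling into xpu::*. Below is a minimal standalone sketch of that pattern; the HostFP16/DeviceFP16 structs are toy stand-ins for the real Paddle/XDNN types, and only the trait lookup and pointer casts mirror the patched code.

    #include <cstdint>

    // Toy stand-ins for phi::dtype::float16 and XDNN's float16; both wrap
    // the same 16-bit payload, which is what makes the pointer cast legal.
    struct HostFP16 { uint16_t bits; };
    struct DeviceFP16 { uint16_t bits; };

    // Identity mapping for types the device API accepts directly ...
    template <typename T>
    struct XPUTypeTrait { using Type = T; };
    // ... and a specialization that swaps in the device-side fp16 type.
    template <>
    struct XPUTypeTrait<HostFP16> { using Type = DeviceFP16; };

    template <typename T>
    void Kernel(const T* x, T* out, int n) {
      using XPUType = typename XPUTypeTrait<T>::Type;
      // A real kernel would pass these casted pointers to the device call,
      // exactly as gather_nd does with reinterpret_cast<const XPUType*>.
      const XPUType* dev_x = reinterpret_cast<const XPUType*>(x);
      XPUType* dev_out = reinterpret_cast<XPUType*>(out);
      (void)dev_x; (void)dev_out; (void)n;
    }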