未验证 提交 307128d1 编写于 作者: J jiangfan06 提交者: GitHub

[XPU] Add gather_nd fp16 and add check_dtype_op_blacklist (#55860)

上级 b546b923
......@@ -406,7 +406,10 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
support_low_precision = OpSupportPrecision(
GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
if (op_node->Op()->HasAttr("dtype")) {
std::unordered_set<std::string> check_dtype_op_blacklist(
{"arg_max", "arg_min"});
if (op_node->Op()->HasAttr("dtype") &&
!check_dtype_op_blacklist.count(GetOpOriginalType(op_type))) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_low_precision =
support_low_precision &&
......
......@@ -397,7 +397,8 @@ XPUOpMap& get_kl2_ops() {
{"gather_nd",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32})},
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"gather",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
......
......@@ -24,6 +24,7 @@ void GatherNdKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &index,
DenseTensor *out) {
using XPUType = typename XPUTypeTrait<T>::Type;
ctx.template Alloc<T>(out);
if (x.numel() == 0) {
......@@ -57,8 +58,8 @@ void GatherNdKernel(const Context &ctx,
// int broadcast(Context* ctx, const T* x, T* y, const std::vector<int>&
// xshape, const std::vector<int>& yshape)
int r = xpu::broadcast(ctx.x_context(),
x.data<T>(),
out->data<T>(),
reinterpret_cast<const XPUType *>(x.data<T>()),
reinterpret_cast<XPUType *>(out->data<T>()),
{1, x_numel},
{remain_numel, x_numel});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
......@@ -87,17 +88,19 @@ void GatherNdKernel(const Context &ctx,
int ret = XPU_SUCCESS;
if (index_type == DataType::INT32) {
ret = xpu::gather_nd<T, int>(ctx.x_context(),
x.data<T>(),
ret = xpu::gather_nd<XPUType, int>(
ctx.x_context(),
reinterpret_cast<const XPUType *>(x.data<T>()),
index.data<int>(),
out->data<T>(),
reinterpret_cast<XPUType *>(out->data<T>()),
x_vec,
index_shape);
} else {
ret = xpu::gather_nd<T, int64_t>(ctx.x_context(),
x.data<T>(),
ret = xpu::gather_nd<XPUType, int64_t>(
ctx.x_context(),
reinterpret_cast<const XPUType *>(x.data<T>()),
index.data<int64_t>(),
out->data<T>(),
reinterpret_cast<XPUType *>(out->data<T>()),
x_vec,
index_shape);
}
......@@ -106,5 +109,11 @@ void GatherNdKernel(const Context &ctx,
} // namespace phi
// Registers phi::GatherNdKernel under the op name `gather_nd` for the XPU
// backend with ALL_LAYOUT, instantiated for float, int64_t and int.
// NOTE(review): the hunk header above (-106,5 +109,11) suggests this compact
// form is the pre-change registration and that the multi-line invocation
// following it is its replacement — confirm against the repository.
// The trailing `{}` is left empty; presumably it is the macro's optional
// customization body — TODO confirm.
PD_REGISTER_KERNEL(
gather_nd, XPU, ALL_LAYOUT, phi::GatherNdKernel, float, int64_t, int) {}
// Registers phi::GatherNdKernel under the op name `gather_nd` for the XPU
// backend with ALL_LAYOUT, instantiated for float, int64_t, int and
// phi::dtype::float16 (the fp16 support named in the commit title).
// The kernel body reinterprets T* buffers as XPUTypeTrait<T>::Type* before
// calling xpu::broadcast / xpu::gather_nd, which is what lets the float16
// instantiation be passed to those calls.
// NOTE(review): the `gather_nd` entry in get_kl2_ops() lists FLOAT16 as well;
// presumably the two lists must stay in sync — confirm.
// The trailing `{}` is left empty; presumably it is the macro's optional
// customization body — TODO confirm.
PD_REGISTER_KERNEL(gather_nd,
XPU,
ALL_LAYOUT,
phi::GatherNdKernel,
float,
int64_t,
int,
phi::dtype::float16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册