Unverified commit 14094aad, authored by jiangfan06, committed by GitHub

[XPU] Add FP16 support for arg_min_max (#55642)

Parent: a7567cd0
@@ -36,7 +36,10 @@ XPUOpMap& get_kl2_ops() {
       {"adam_dense_param_sparse_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"adagrad", XPUKernelSet({phi::DataType::FLOAT32})},
-      {"arg_max", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"arg_max",
+       XPUKernelSet({phi::DataType::INT32,
+                     phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT16})},
       {"argsort_grad",
        XPUKernelSet({phi::DataType::INT32,
                      phi::DataType::INT64,
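This first hunk extends the KL2 op list so that arg_max on XPU accepts INT32 and FLOAT16 inputs in addition to FLOAT32; without a FLOAT16 entry here, an fp16 argmax cannot be dispatched to the XPU kernel at all. Below is a minimal self-contained sketch of the whitelist pattern, with simplified stand-ins (DataType, KernelSet, OpMap, and the helper functions are illustrative, not Paddle's actual definitions):

```cpp
#include <iostream>
#include <map>
#include <set>
#include <string>

enum class DataType { INT32, INT64, FLOAT32, FLOAT16 };
using KernelSet = std::set<DataType>;
using OpMap = std::map<std::string, KernelSet>;

// Each op name maps to the set of dtypes its XPU kernel accepts.
OpMap& get_kl2_ops_sketch() {
  static OpMap ops{
      {"arg_max", {DataType::INT32, DataType::FLOAT32, DataType::FLOAT16}},
  };
  return ops;
}

// Dispatch checks membership before choosing the XPU kernel.
bool supported_on_xpu(const std::string& op, DataType dt) {
  const OpMap& ops = get_kl2_ops_sketch();
  auto it = ops.find(op);
  return it != ops.end() && it->second.count(dt) > 0;
}

int main() {
  // fp16 arg_max is now whitelisted, so this prints 1.
  std::cout << supported_on_xpu("arg_max", DataType::FLOAT16) << "\n";
}
```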
@@ -35,6 +35,7 @@ void ArgMaxKernel(const Context& dev_ctx,
                   bool flatten,
                   int dtype,
                   DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   PADDLE_ENFORCE_EQ(
       (dtype < 0 || dtype == ARG_MAX_OUTPUT_DATATYPE_INT32 ||
        dtype == ARG_MAX_OUTPUT_DATATYPE_INT64),
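The added XPUType alias maps the kernel's template parameter T to the element type the XPU device API expects. For float and int the mapping is the identity; for fp16 the framework type (phi::dtype::float16) and the device type differ only as C++ types, not in bit layout, which is what makes the reinterpret_cast in the following hunks safe. A minimal sketch of this trait pattern, assuming layout-compatible fp16 types (PhiFloat16, XpuFloat16, XPUTypeTraitSketch, and to_device are hypothetical stand-ins):

```cpp
#include <cstdint>
#include <vector>

struct PhiFloat16 { uint16_t bits; };  // framework-side fp16 stand-in
struct XpuFloat16 { uint16_t bits; };  // device-API fp16 stand-in

// Identity mapping for types the framework and device API share...
template <typename T>
struct XPUTypeTraitSketch { using Type = T; };

// ...and an explicit specialization for fp16, where only the C++ type differs.
template <>
struct XPUTypeTraitSketch<PhiFloat16> { using Type = XpuFloat16; };

// Mirrors the kernel's pattern: alias the device type once, then
// reinterpret the layout-compatible data pointer instead of copying.
template <typename T>
const typename XPUTypeTraitSketch<T>::Type* to_device(const T* p) {
  using XPUType = typename XPUTypeTraitSketch<T>::Type;
  return reinterpret_cast<const XPUType*>(p);
}

int main() {
  std::vector<PhiFloat16> x(4);
  const XpuFloat16* dev = to_device(x.data());  // a view, not a copy
  (void)dev;
}
```

Because the conversion is a pointer reinterpretation, enabling fp16 adds no extra copies; the only kernel changes needed are the casts in the two call sites below.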
@@ -69,7 +70,7 @@ void ArgMaxKernel(const Context& dev_ctx,
     return;
   }
   r = xpu::argmax(dev_ctx.x_context(),
-                  x.data<T>(),
+                  reinterpret_cast<const XPUType*>(x.data<T>()),
                   out->data<int64_t>(),
                   xdims_vec,
                   axis_val);
@@ -90,7 +91,7 @@ void ArgMaxKernel(const Context& dev_ctx,
                        static_cast<int64_t>(0));
   } else {
     r = xpu::argmax(dev_ctx.x_context(),
-                    x.data<T>(),
+                    reinterpret_cast<const XPUType*>(x.data<T>()),
                     out_int64.data<int64_t>(),
                     xdims_vec,
                     axis_val);
@@ -116,6 +117,12 @@ void ArgMaxKernel(const Context& dev_ctx,
   }
 }
 } // namespace phi
-PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {
+PD_REGISTER_KERNEL(argmax,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::ArgMaxKernel,
+                   float,
+                   int,
+                   phi::dtype::float16) {
   kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
 }
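The registration macro now instantiates the kernel for float, int, and phi::dtype::float16 inputs, matching the op-list entry above. The output dtype is left UNDEFINED at registration time because the kernel selects it at run time from the dtype attribute: per the hunks above, xpu::argmax is always called with an int64 output buffer, and an INT32 request appears to go through the out_int64 temporary before being narrowed. A self-contained sketch of that run-time flow (the enum values and device_argmax are hypothetical stand-ins, not Paddle's code):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for the attribute values the kernel checks.
enum OutDType { kArgMaxOutInt32 = 2, kArgMaxOutInt64 = 3 };

// Device argmax stand-in: yields an int64 index, matching how
// xpu::argmax is called with an int64 output buffer above.
int64_t device_argmax(const std::vector<float>& x) {
  return std::max_element(x.begin(), x.end()) - x.begin();
}

int main() {
  std::vector<float> x = {0.5f, 2.0f, -1.0f};
  int dtype = kArgMaxOutInt32;        // run-time attribute chosen by the caller
  int64_t idx64 = device_argmax(x);   // int64 scratch, like out_int64
  if (dtype == kArgMaxOutInt32) {
    int32_t idx32 = static_cast<int32_t>(idx64);  // narrow to the requested dtype
    std::cout << idx32 << "\n";
  } else {
    std::cout << idx64 << "\n";
  }
}
```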