未验证 提交 14094aad 编写于 作者: J jiangfan06 提交者: GitHub

[XPU] Add FP16 support for arg_min_max (#55642)

上级 a7567cd0
...@@ -36,7 +36,10 @@ XPUOpMap& get_kl2_ops() { ...@@ -36,7 +36,10 @@ XPUOpMap& get_kl2_ops() {
{"adam_dense_param_sparse_grad", {"adam_dense_param_sparse_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"adagrad", XPUKernelSet({phi::DataType::FLOAT32})}, {"adagrad", XPUKernelSet({phi::DataType::FLOAT32})},
{"arg_max", XPUKernelSet({phi::DataType::FLOAT32})}, {"arg_max",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"argsort_grad", {"argsort_grad",
XPUKernelSet({phi::DataType::INT32, XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64, phi::DataType::INT64,
......
...@@ -35,6 +35,7 @@ void ArgMaxKernel(const Context& dev_ctx, ...@@ -35,6 +35,7 @@ void ArgMaxKernel(const Context& dev_ctx,
bool flatten, bool flatten,
int dtype, int dtype,
DenseTensor* out) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == ARG_MAX_OUTPUT_DATATYPE_INT32 || (dtype < 0 || dtype == ARG_MAX_OUTPUT_DATATYPE_INT32 ||
dtype == ARG_MAX_OUTPUT_DATATYPE_INT64), dtype == ARG_MAX_OUTPUT_DATATYPE_INT64),
...@@ -69,7 +70,7 @@ void ArgMaxKernel(const Context& dev_ctx, ...@@ -69,7 +70,7 @@ void ArgMaxKernel(const Context& dev_ctx,
return; return;
} }
r = xpu::argmax(dev_ctx.x_context(), r = xpu::argmax(dev_ctx.x_context(),
x.data<T>(), reinterpret_cast<const XPUType*>(x.data<T>()),
out->data<int64_t>(), out->data<int64_t>(),
xdims_vec, xdims_vec,
axis_val); axis_val);
...@@ -90,7 +91,7 @@ void ArgMaxKernel(const Context& dev_ctx, ...@@ -90,7 +91,7 @@ void ArgMaxKernel(const Context& dev_ctx,
static_cast<int64_t>(0)); static_cast<int64_t>(0));
} else { } else {
r = xpu::argmax(dev_ctx.x_context(), r = xpu::argmax(dev_ctx.x_context(),
x.data<T>(), reinterpret_cast<const XPUType*>(x.data<T>()),
out_int64.data<int64_t>(), out_int64.data<int64_t>(),
xdims_vec, xdims_vec,
axis_val); axis_val);
...@@ -116,6 +117,12 @@ void ArgMaxKernel(const Context& dev_ctx, ...@@ -116,6 +117,12 @@ void ArgMaxKernel(const Context& dev_ctx,
} }
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) { PD_REGISTER_KERNEL(argmax,
XPU,
ALL_LAYOUT,
phi::ArgMaxKernel,
float,
int,
phi::dtype::float16) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册