Unverified commit c7251b96, authored by csy0225, committed by GitHub

[XPU] Argmax kernel output supports the int32 data type. (#51303)

Parent d7660a7c
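Summary of the change (as read from the diff below): previously the XPU ArgMaxKernel always allocated an int64 output, regardless of the requested dtype attribute. This commit names the two accepted dtype codes as constants (2 for int32, 3 for int64) and adds an int32 path: since xpu::argmax in this kernel always writes int64 indices, the int32 case computes into a temporary int64 tensor and then narrows it on-device with xpu::cast_v2<int64_t, int>. A minimal host-side sketch of this compute-wide-then-narrow pattern follows the diff.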
@@ -22,6 +22,11 @@
 namespace phi {
+namespace {
+const int ARG_MAX_OUTPUT_DATATYPE_INT32 = 2;
+const int ARG_MAX_OUTPUT_DATATYPE_INT64 = 3;
+}  // Anonymous namespace
+
 template <typename T, typename Context>
 void ArgMaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
@@ -31,7 +36,8 @@ void ArgMaxKernel(const Context& dev_ctx,
                   int dtype,
                   DenseTensor* out) {
   PADDLE_ENFORCE_EQ(
-      (dtype < 0 || dtype == 2 || dtype == 3),
+      (dtype < 0 || dtype == ARG_MAX_OUTPUT_DATATYPE_INT32 ||
+       dtype == ARG_MAX_OUTPUT_DATATYPE_INT64),
       true,
       errors::InvalidArgument(
           "The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], "
@@ -41,15 +47,6 @@ void ArgMaxKernel(const Context& dev_ctx,
           DataType::INT32,
           dtype));
   // TODO(ZHUI): fix dtype of out
-  dev_ctx.template Alloc<int64_t>(out);
-  if (x.dims().size() == 0) {
-    xpu::constant(dev_ctx.x_context(),
-                  out->data<int64_t>(),
-                  x.numel(),
-                  static_cast<int64_t>(0));
-    return;
-  }
   DDim x_dims;
   int axis_val = axis.to<int>();
   if (flatten) {
@@ -61,17 +58,62 @@ void ArgMaxKernel(const Context& dev_ctx,
     if (axis_val < 0) axis_val += x_dims.size();
   }
   auto xdims_vec = phi::vectorize<int>(x_dims);
-  int r = xpu::argmax(dev_ctx.x_context(),
-                      x.data<T>(),
-                      out->data<int64_t>(),
-                      xdims_vec,
-                      axis_val);
-  PADDLE_ENFORCE_EQ(
-      r,
-      XPU_SUCCESS,
-      errors::External("XPU argmax kernel return wrong value[%d %s].",
-                       r,
-                       XPUAPIErrorMsg[r]));
+  int r = 0;
+  if (dtype != ARG_MAX_OUTPUT_DATATYPE_INT32) {
+    dev_ctx.template Alloc<int64_t>(out);
+    if (x.dims().size() == 0) {
+      xpu::constant(dev_ctx.x_context(),
+                    out->data<int64_t>(),
+                    x.numel(),
+                    static_cast<int64_t>(0));
+      return;
+    }
+    r = xpu::argmax(dev_ctx.x_context(),
+                    x.data<T>(),
+                    out->data<int64_t>(),
+                    xdims_vec,
+                    axis_val);
+    PADDLE_ENFORCE_EQ(
+        r,
+        XPU_SUCCESS,
+        errors::External("XPU argmax kernel return wrong value[%d %s].",
+                         r,
+                         XPUAPIErrorMsg[r]));
+  } else {
+    DenseTensor out_int64;
+    out_int64.Resize(out->dims());
+    dev_ctx.template Alloc<int64_t>(&out_int64);
+    if (x.dims().size() == 0) {
+      xpu::constant(dev_ctx.x_context(),
+                    out_int64.data<int64_t>(),
+                    x.numel(),
+                    static_cast<int64_t>(0));
+    } else {
+      r = xpu::argmax(dev_ctx.x_context(),
+                      x.data<T>(),
+                      out_int64.data<int64_t>(),
+                      xdims_vec,
+                      axis_val);
+    }
+
+    PADDLE_ENFORCE_EQ(
+        r,
+        XPU_SUCCESS,
+        errors::External("XPU argmax kernel return wrong value[%d %s].",
+                         r,
+                         XPUAPIErrorMsg[r]));
+    dev_ctx.template Alloc<int>(out);
+    r = xpu::cast_v2<int64_t, int>(dev_ctx.x_context(),
+                                   out_int64.data<int64_t>(),
+                                   out->data<int>(),
+                                   out_int64.numel());
+    PADDLE_ENFORCE_EQ(
+        r,
+        XPU_SUCCESS,
+        errors::External("XPU cast kernel return wrong value[%d %s].",
+                         r,
+                         XPUAPIErrorMsg[r]));
+  }
 }
 }  // namespace phi
 PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {
...
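For intuition, here is a minimal host-side sketch of the pattern the int32 branch uses: compute the argmax index in the wide int64 type, then narrow it to int32 for the requested output dtype. This is a hypothetical standalone analogue in plain C++, not Paddle or XPU API code; in the kernel itself the narrowing runs on-device via xpu::cast_v2<int64_t, int>.

#include <cstdint>
#include <iostream>
#include <vector>

// Argmax over a 1-D buffer, producing the index in the wide type (int64_t),
// mirroring how the kernel always computes int64 indices first.
int64_t ArgMaxInt64(const std::vector<float>& x) {
  int64_t best = 0;
  for (int64_t i = 1; i < static_cast<int64_t>(x.size()); ++i) {
    if (x[i] > x[best]) best = i;
  }
  return best;
}

int main() {
  const std::vector<float> x = {0.5f, 2.0f, 1.5f};
  const int64_t idx64 = ArgMaxInt64(x);
  // Narrow to int32 at the end, as the kernel's cast step does for every
  // element of the temporary int64 result tensor.
  const int32_t idx32 = static_cast<int32_t>(idx64);
  std::cout << idx32 << std::endl;  // prints 1
  return 0;
}

Keeping the computation in int64 and casting only at the end means the int32 support is purely an output-dtype change; the index computation itself is untouched.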