Unverified · Commit c7251b96 · Authored by csy0225 · Committed by GitHub

[XPU] Argmax kernel output support int32 data type. (#51303)

Parent d7660a7c
@@ -22,6 +22,11 @@
 namespace phi {
+namespace {
+const int ARG_MAX_OUTPUT_DATATYPE_INT32 = 2;
+const int ARG_MAX_OUTPUT_DATATYPE_INT64 = 3;
+}  // Anonymous namespace
+
 template <typename T, typename Context>
 void ArgMaxKernel(const Context& dev_ctx,
                   const DenseTensor& x,
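
The two named constants replace the magic numbers 2 and 3 used in the check below. Those values come from the integer encoding of the op's dtype attribute, which mirrors Paddle's framework-level VarType proto (where, among others, INT32 is 2 and INT64 is 3). A minimal sketch of the assumed mapping, not taken from the Paddle headers:

// Hedged sketch: the dtype attribute arrives as a plain int whose values
// are assumed to follow the framework's VarType encoding.
enum class AttrDataType : int {
  kInt32 = 2,  // matches ARG_MAX_OUTPUT_DATATYPE_INT32
  kInt64 = 3,  // matches ARG_MAX_OUTPUT_DATATYPE_INT64
};
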
@@ -31,7 +36,8 @@ void ArgMaxKernel(const Context& dev_ctx,
                   int dtype,
                   DenseTensor* out) {
   PADDLE_ENFORCE_EQ(
-      (dtype < 0 || dtype == 2 || dtype == 3),
+      (dtype < 0 || dtype == ARG_MAX_OUTPUT_DATATYPE_INT32 ||
+       dtype == ARG_MAX_OUTPUT_DATATYPE_INT64),
       true,
       errors::InvalidArgument(
           "The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], "
@@ -41,15 +47,6 @@ void ArgMaxKernel(const Context& dev_ctx,
           DataType::INT32,
           dtype));
 
-  // TODO(ZHUI): fix dtype of out
-  dev_ctx.template Alloc<int64_t>(out);
-  if (x.dims().size() == 0) {
-    xpu::constant(dev_ctx.x_context(),
-                  out->data<int64_t>(),
-                  x.numel(),
-                  static_cast<int64_t>(0));
-    return;
-  }
   DDim x_dims;
   int axis_val = axis.to<int>();
   if (flatten) {
@@ -61,17 +58,62 @@ void ArgMaxKernel(const Context& dev_ctx,
     if (axis_val < 0) axis_val += x_dims.size();
   }
   auto xdims_vec = phi::vectorize<int>(x_dims);
-  int r = xpu::argmax(dev_ctx.x_context(),
-                      x.data<T>(),
-                      out->data<int64_t>(),
-                      xdims_vec,
-                      axis_val);
-  PADDLE_ENFORCE_EQ(
-      r,
-      XPU_SUCCESS,
-      errors::External("XPU argmax kernel return wrong value[%d %s].",
-                       r,
-                       XPUAPIErrorMsg[r]));
+  int r = 0;
+  if (dtype != ARG_MAX_OUTPUT_DATATYPE_INT32) {
+    dev_ctx.template Alloc<int64_t>(out);
+    if (x.dims().size() == 0) {
+      xpu::constant(dev_ctx.x_context(),
+                    out->data<int64_t>(),
+                    x.numel(),
+                    static_cast<int64_t>(0));
+      return;
+    }
+    r = xpu::argmax(dev_ctx.x_context(),
+                    x.data<T>(),
+                    out->data<int64_t>(),
+                    xdims_vec,
+                    axis_val);
+    PADDLE_ENFORCE_EQ(
+        r,
+        XPU_SUCCESS,
+        errors::External("XPU argmax kernel return wrong value[%d %s].",
+                         r,
+                         XPUAPIErrorMsg[r]));
+  } else {
+    DenseTensor out_int64;
+    out_int64.Resize(out->dims());
+    dev_ctx.template Alloc<int64_t>(&out_int64);
+    if (x.dims().size() == 0) {
+      xpu::constant(dev_ctx.x_context(),
+                    out_int64.data<int64_t>(),
+                    x.numel(),
+                    static_cast<int64_t>(0));
+    } else {
+      r = xpu::argmax(dev_ctx.x_context(),
+                      x.data<T>(),
+                      out_int64.data<int64_t>(),
+                      xdims_vec,
+                      axis_val);
+      PADDLE_ENFORCE_EQ(
+          r,
+          XPU_SUCCESS,
+          errors::External("XPU argmax kernel return wrong value[%d %s].",
+                           r,
+                           XPUAPIErrorMsg[r]));
+    }
+    dev_ctx.template Alloc<int>(out);
+    r = xpu::cast_v2<int64_t, int>(dev_ctx.x_context(),
+                                   out_int64.data<int64_t>(),
+                                   out->data<int>(),
+                                   out_int64.numel());
+    PADDLE_ENFORCE_EQ(
+        r,
+        XPU_SUCCESS,
+        errors::External("XPU cast kernel return wrong value[%d %s].",
+                         r,
+                         XPUAPIErrorMsg[r]));
+  }
 }
 
 }  // namespace phi
PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {
......
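
Stripped of the Paddle plumbing, the new int32 branch computes indices in int64 (the only output type xpu::argmax produces) into a temporary tensor and then narrows them to int32 via xpu::cast_v2. A self-contained CPU-side sketch of that two-step semantics, using a hypothetical argmax_int64 helper rather than the real XPU calls:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

// Hypothetical stand-in for xpu::argmax on a flat array: the index of the
// maximum element, emitted as int64.
int64_t argmax_int64(const std::vector<float>& x) {
  return std::distance(x.begin(), std::max_element(x.begin(), x.end()));
}

int main() {
  std::vector<float> x = {0.5f, 2.0f, 1.5f};
  int64_t idx64 = argmax_int64(x);              // device kernel emits int64
  int32_t idx32 = static_cast<int32_t>(idx64);  // cast_v2 step: int64 -> int32
  std::cout << idx32 << "\n";                   // prints 1
  return 0;
}
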