Unverified commit 750abc2c, authored by Aurelius84, committed by GitHub

[XPU]Migrate argsort and arg_max XPU kernel into Phi (#45576)

* [XPU]Migrate argsort and arg_max XPU kernel into Phi

* test=kunlun

* test=kunlun
Parent: a0e3a175
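At a glance, the change follows the standard fluid-to-Phi kernel pattern: an ExecutionContext-based OpKernel class becomes a free function templated on Context, with inputs and attributes as explicit parameters, and registration moves from REGISTER_OP_XPU_KERNEL to PD_REGISTER_KERNEL. Signatures excerpted from the diff below (bodies and scaffolding elided):

// Before (fluid): inputs and attributes are fetched from the ExecutionContext.
template <typename DeviceContext, typename T>
class ArgMaxXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override;
};
REGISTER_OP_XPU_KERNEL(
    arg_max, ops::ArgMaxXPUKernel<paddle::platform::XPUDeviceContext, float>);

// After (Phi): a free function; the device context, inputs, and attributes
// are explicit parameters, and registration carries device and layout.
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const Scalar& axis,
                  bool keepdims,
                  bool flatten,
                  int dtype,
                  DenseTensor* out);
PD_REGISTER_KERNEL(arg_max, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {}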
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU

#include "paddle/fluid/operators/arg_min_max_op_base.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class ArgMaxXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* out = ctx.Output<framework::LoDTensor>("Out");
    auto dtype = ctx.Attr<int>("dtype");
    // Legacy VarType encoding: -1 means not set, 2 is INT32, 3 is INT64.
    PADDLE_ENFORCE_EQ(
        (dtype < 0 || dtype == 2 || dtype == 3),
        true,
        platform::errors::InvalidArgument(
            "The attribute of dtype in xpu argmin/argmax must be [%s] or "
            "[%s], but received [%s]",
            paddle::framework::DataTypeToString(
                framework::proto::VarType::INT64),
            paddle::framework::DataTypeToString(
                framework::proto::VarType::INT32),
            paddle::framework::DataTypeToString(
                static_cast<framework::proto::VarType::Type>(dtype))));

    out->template mutable_data<int64_t>(ctx.GetPlace());
    auto axis = ctx.Attr<int64_t>("axis");
    const bool& flatten = ctx.Attr<bool>("flatten");
    framework::DDim x_dims;
    if (flatten) {
      x_dims = phi::make_ddim({x->numel()});
      // if flatten is true, argmax runs over the flattened tensor, so axis is 0
      axis = 0;
    } else {
      x_dims = x->dims();
      if (axis < 0) axis += x_dims.size();
    }
    auto xdims_vec = phi::vectorize<int>(x_dims);
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    int r = xpu::argmax(dev_ctx.x_context(),
                        x->data<T>(),
                        out->data<int64_t>(),
                        xdims_vec,
                        axis);
    PADDLE_ENFORCE_EQ(r,
                      XPU_SUCCESS,
                      platform::errors::External(
                          "XPU argmax kernel return wrong value[%d %s].",
                          r,
                          XPUAPIErrorMsg[r]));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    arg_max, ops::ArgMaxXPUKernel<paddle::platform::XPUDeviceContext, float>);

#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/arg_min_max_kernel.h"

#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const Scalar& axis,
                  bool keepdims,
                  bool flatten,
                  int dtype,
                  DenseTensor* out) {
  // Legacy VarType encoding: -1 means not set, 2 is INT32, 3 is INT64.
  PADDLE_ENFORCE_EQ(
      (dtype < 0 || dtype == 2 || dtype == 3),
      true,
      errors::InvalidArgument(
          "The attribute of dtype in xpu argmin/argmax must be [%s] or "
          "[%s], but received [%s]",
          DataType::INT64,
          DataType::INT32,
          dtype));

  dev_ctx.template Alloc<int64_t>(out);
  DDim x_dims;
  int axis_val = axis.to<int>();
  if (flatten) {
    x_dims = phi::make_ddim({x.numel()});
    // if flatten is true, argmax runs over the flattened tensor, so axis is 0
    axis_val = 0;
  } else {
    x_dims = x.dims();
    if (axis_val < 0) axis_val += x_dims.size();
  }
  auto xdims_vec = phi::vectorize<int>(x_dims);
  int r = xpu::argmax(dev_ctx.x_context(),
                      x.data<T>(),
                      out->data<int64_t>(),
                      xdims_vec,
                      axis_val);
  PADDLE_ENFORCE_EQ(
      r,
      XPU_SUCCESS,
      errors::External("XPU argmax kernel return wrong value[%d %s].",
                       r,
                       XPUAPIErrorMsg[r]));
}

}  // namespace phi

PD_REGISTER_KERNEL(arg_max, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {}
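Both the fluid and Phi versions above share the same flatten/negative-axis handling. A minimal standalone sketch of that logic, runnable outside Paddle (the helper name and values here are hypothetical, for illustration only):

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the axis normalization used by ArgMaxKernel: flatten forces
// axis 0 over a 1-D view of the whole tensor; otherwise a negative axis
// is wrapped by the tensor rank.
std::vector<int> NormalizedDims(const std::vector<int>& dims,
                                int* axis,
                                bool flatten) {
  if (flatten) {
    int64_t numel = 1;
    for (int d : dims) numel *= d;
    *axis = 0;
    return {static_cast<int>(numel)};
  }
  if (*axis < 0) *axis += static_cast<int>(dims.size());
  return dims;
}

int main() {
  int axis = -1;
  auto dims = NormalizedDims({2, 3, 4}, &axis, /*flatten=*/false);
  std::cout << "axis=" << axis << "\n";        // prints axis=2
  axis = -1;
  dims = NormalizedDims({2, 3, 4}, &axis, /*flatten=*/true);
  std::cout << "dims[0]=" << dims[0] << "\n";  // prints dims[0]=24
  return 0;
}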
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
-#ifdef PADDLE_WITH_XPU
-#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/kernels/argsort_kernel.h"
+
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
 
 const int XPU_SORT_MAX_SIZE = 16384;

@@ -34,7 +34,7 @@ static inline void xpu_argsort(xpu::Context* ctx,
   PADDLE_ENFORCE_EQ(
       ret,
       XPU_SUCCESS,
-      platform::errors::External("XPU sort kernel return wrong value[%d %s].",
-                                 ret,
-                                 XPUAPIErrorMsg[ret]));
+      errors::External("XPU sort kernel return wrong value[%d %s].",
+                       ret,
+                       XPUAPIErrorMsg[ret]));
 }

@@ -46,10 +46,10 @@ static inline void xpu_transpose(xpu::Context* ctx,
                                  const std::vector<int>& xshape,
                                  const std::vector<int>& permute) {
   int ret = xpu::transpose(ctx, x, y, xshape, permute);
-  PADDLE_ENFORCE_EQ(ret,
-                    XPU_SUCCESS,
-                    platform::errors::External(
-                        "XPU transpose kernel return wrong value[%d %s]",
-                        ret,
-                        XPUAPIErrorMsg[ret]));
+  PADDLE_ENFORCE_EQ(
+      ret,
+      XPU_SUCCESS,
+      errors::External("XPU transpose kernel return wrong value[%d %s]",
+                       ret,
+                       XPUAPIErrorMsg[ret]));
 }

@@ -60,7 +60,7 @@ static inline void xpu_cast(xpu::Context* ctx, const TX* x, TY* y, int len) {
   PADDLE_ENFORCE_EQ(
       ret,
       XPU_SUCCESS,
-      platform::errors::External("XPU cast kernel return wrong value[%d %s]",
-                                 ret,
-                                 XPUAPIErrorMsg[ret]));
+      errors::External("XPU cast kernel return wrong value[%d %s]",
+                       ret,
+                       XPUAPIErrorMsg[ret]));
 }

@@ -179,40 +179,34 @@ struct XPUArgsort<int64_t, true, true> {
   }
 };
 
-template <typename T>
-class ArgsortXPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* input = ctx.Input<framework::Tensor>("X");
-    auto* output = ctx.Output<framework::Tensor>("Out");
-    auto* indices = ctx.Output<framework::Tensor>("Indices");
-    int axis = ctx.Attr<int>("axis");
-    bool descending = ctx.Attr<bool>("descending");
-
-    auto in_dims = input->dims();
-    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
-    int n = in_dims[axis];
-
-    PADDLE_ENFORCE_LT(
-        n,
-        XPU_SORT_MAX_SIZE,
-        platform::errors::InvalidArgument(
-            "The axis dimension of Input should less than %d, but got %d.",
-            XPU_SORT_MAX_SIZE,
-            in_dims[axis]));
-
-    auto input_data = input->data<T>();
-    auto output_data = output->mutable_data<T>(ctx.GetPlace());
-    auto indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
-
-    auto& dev_ctx =
-        ctx.template device_context<paddle::platform::XPUDeviceContext>();
-    int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis));
-    int len_after =
-        phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size()));
-    bool int64_need_cast =
-        (std::is_same<T, int64_t>::value && n > (XPU_SORT_MAX_SIZE / 2))
-            ? true
-            : false;
-    bool index_need_cast = (n > (XPU_SORT_MAX_SIZE / 2)) ? true : false;
-    std::vector<int> permute_vec{0, 2, 1};
+template <typename T, typename Context>
+void ArgsortKernel(const Context& dev_ctx,
+                   const DenseTensor& input,
+                   int axis,
+                   bool descending,
+                   DenseTensor* output,
+                   DenseTensor* indices) {
+  auto in_dims = input.dims();
+  axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+  int n = in_dims[axis];
+
+  PADDLE_ENFORCE_LT(
+      n,
+      XPU_SORT_MAX_SIZE,
+      errors::InvalidArgument(
+          "The axis dimension of Input should less than %d, but got %d.",
+          XPU_SORT_MAX_SIZE,
+          in_dims[axis]));
+
+  auto input_data = input.data<T>();
+  auto output_data = dev_ctx.template Alloc<T>(output);
+  auto indices_data = dev_ctx.template Alloc<int64_t>(indices);
+
+  int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis));
+  int len_after =
+      phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size()));
+  bool int64_need_cast =
+      (std::is_same<T, int64_t>::value && n > (XPU_SORT_MAX_SIZE / 2)) ? true
+                                                                       : false;
+  bool index_need_cast = (n > (XPU_SORT_MAX_SIZE / 2)) ? true : false;
+  std::vector<int> permute_vec{0, 2, 1};

@@ -243,18 +237,9 @@ class ArgsortXPUKernel : public framework::OpKernel<T> {
-          permute_vec,
-          descending);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_XPU_KERNEL(argsort,
-                       ops::ArgsortXPUKernel<float>,
-                       ops::ArgsortXPUKernel<int>,
-                       ops::ArgsortXPUKernel<int64_t>);
-
-#endif
+        permute_vec,
+        descending);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    argsort, XPU, ALL_LAYOUT, phi::ArgsortKernel, float, int, int64_t) {}
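For readers following the shape bookkeeping in the argsort kernel: the input is viewed as [len_before, n, len_after] and transposed with permute {0, 2, 1} so that the sort axis becomes the innermost, contiguous dimension before xpu::sort runs (the XPUArgsort functors doing this are elided from the diff). A standalone sketch of that arithmetic, with hypothetical dimensions, independent of Paddle:

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> in_dims{2, 5, 3};  // hypothetical input shape
  int axis = 1;                       // sort along the middle axis, n = 5
  int n = in_dims[axis];
  // len_before/len_after mirror phi::product over the sliced dims above.
  int len_before = std::accumulate(
      in_dims.begin(), in_dims.begin() + axis, 1, std::multiplies<int>());
  int len_after = std::accumulate(
      in_dims.begin() + axis + 1, in_dims.end(), 1, std::multiplies<int>());
  std::vector<int> xshape{len_before, n, len_after};  // {2, 5, 3}
  std::vector<int> permute_vec{0, 2, 1};  // transposed view: {2, 3, 5}
  assert(len_before == 2 && n == 5 && len_after == 3);
  return 0;
}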