diff --git a/paddle/fluid/operators/arg_max_op_xpu.cc b/paddle/fluid/operators/arg_max_op_xpu.cc deleted file mode 100644 index 1077a73a827129bcdde37db69025ff469710d5cc..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/arg_max_op_xpu.cc +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU - -#include "paddle/fluid/operators/arg_min_max_op_base.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; - -template -class ArgMaxXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - auto dtype = ctx.Attr("dtype"); - PADDLE_ENFORCE_EQ( - (dtype < 0 || dtype == 2 || dtype == 3), - true, - platform::errors::InvalidArgument( - "The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], " - "but " - "received [%s]", - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - static_cast(dtype)))); - - out->template mutable_data(ctx.GetPlace()); - auto axis = ctx.Attr("axis"); - const bool& flatten = ctx.Attr("flatten"); - framework::DDim x_dims; - if (flatten) { - x_dims = phi::make_ddim({x->numel()}); - // if flatten, the axis just as 0 - axis = 0; - } else { - x_dims = x->dims(); - if (axis < 0) axis += x_dims.size(); - } - auto xdims_vec = phi::vectorize(x_dims); - auto& dev_ctx = ctx.template device_context(); - int r = xpu::argmax(dev_ctx.x_context(), - x->data(), - out->data(), - xdims_vec, - axis); - PADDLE_ENFORCE_EQ(r, - XPU_SUCCESS, - platform::errors::External( - "XPU argmax kernel return wrong value[%d %s].", - r, - XPUAPIErrorMsg[r])); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - arg_max, ops::ArgMaxXPUKernel); - -#endif diff --git a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..a48e2155a251a4c14fb09bc0d29391c1fab8a589 --- /dev/null +++ b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/arg_min_max_kernel.h" + +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void ArgMaxKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& axis, + bool keepdims, + bool flatten, + int dtype, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + (dtype < 0 || dtype == 2 || dtype == 3), + true, + errors::InvalidArgument( + "The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], " + "but " + "received [%s]", + DataType::INT64, + DataType::INT32, + dtype)); + dev_ctx.template Alloc(out); + + DDim x_dims; + int axis_val = axis.to(); + if (flatten) { + x_dims = phi::make_ddim({x.numel()}); + // if flatten, the axis just as 0 + axis_val = 0; + } else { + x_dims = x.dims(); + if (axis_val < 0) axis_val += x_dims.size(); + } + auto xdims_vec = phi::vectorize(x_dims); + int r = xpu::argmax(dev_ctx.x_context(), + x.data(), + out->data(), + xdims_vec, + axis_val); + PADDLE_ENFORCE_EQ( + r, + XPU_SUCCESS, + errors::External("XPU argmax kernel return wrong value[%d %s].", + r, + XPUAPIErrorMsg[r])); +} +} // namespace phi +PD_REGISTER_KERNEL(arg_max, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {} diff --git a/paddle/fluid/operators/argsort_op_xpu.cc b/paddle/phi/kernels/xpu/argsort_kernel.cc similarity index 58% rename from paddle/fluid/operators/argsort_op_xpu.cc rename to paddle/phi/kernels/xpu/argsort_kernel.cc index 95837841cce244b8588e34e8fb510449413d1f32..80db142e15d01f78fd8be50e30df620387697812 100644 --- a/paddle/fluid/operators/argsort_op_xpu.cc +++ b/paddle/phi/kernels/xpu/argsort_kernel.cc @@ -1,23 +1,23 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/argsort_kernel.h" + +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { const int XPU_SORT_MAX_SIZE = 16384; @@ -34,9 +34,9 @@ static inline void xpu_argsort(xpu::Context* ctx, PADDLE_ENFORCE_EQ( ret, XPU_SUCCESS, - platform::errors::External("XPU sort kernel return wrong value[%d %s].", - ret, - XPUAPIErrorMsg[ret])); + errors::External("XPU sort kernel return wrong value[%d %s].", + ret, + XPUAPIErrorMsg[ret])); } template @@ -46,12 +46,12 @@ static inline void xpu_transpose(xpu::Context* ctx, const std::vector& xshape, const std::vector& permute) { int ret = xpu::transpose(ctx, x, y, xshape, permute); - PADDLE_ENFORCE_EQ(ret, - XPU_SUCCESS, - platform::errors::External( - "XPU transpose kernel return wrong value[%d %s]", - ret, - XPUAPIErrorMsg[ret])); + PADDLE_ENFORCE_EQ( + ret, + XPU_SUCCESS, + errors::External("XPU transpose kernel return wrong value[%d %s]", + ret, + XPUAPIErrorMsg[ret])); } template @@ -60,9 +60,9 @@ static inline void xpu_cast(xpu::Context* ctx, const TX* x, TY* y, int len) { PADDLE_ENFORCE_EQ( ret, XPU_SUCCESS, - platform::errors::External("XPU cast kernel return wrong value[%d %s]", - ret, - XPUAPIErrorMsg[ret])); + errors::External("XPU cast kernel return wrong value[%d %s]", + ret, + XPUAPIErrorMsg[ret])); } template { } }; -template -class ArgsortXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); - auto* indices = ctx.Output("Indices"); - int axis = ctx.Attr("axis"); - bool descending = ctx.Attr("descending"); - - auto in_dims = input->dims(); - axis = (axis < 0) ? (in_dims.size() + axis) : axis; - int n = in_dims[axis]; - - PADDLE_ENFORCE_LT( - n, - XPU_SORT_MAX_SIZE, - platform::errors::InvalidArgument( - "The axis dimension of Input should less than %d, but got %d.", - XPU_SORT_MAX_SIZE, - in_dims[axis])); - - auto input_data = input->data(); - auto output_data = output->mutable_data(ctx.GetPlace()); - auto indices_data = indices->mutable_data(ctx.GetPlace()); - - auto& dev_ctx = - ctx.template device_context(); - int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis)); - int len_after = - phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size())); - bool int64_need_cast = - (std::is_same::value && n > (XPU_SORT_MAX_SIZE / 2)) - ? true - : false; - bool index_need_cast = (n > (XPU_SORT_MAX_SIZE / 2)) ? true : false; - std::vector permute_vec{0, 2, 1}; - std::vector data_shape{len_before, n, len_after}; - - if (int64_need_cast) { - XPUArgsort()(dev_ctx.x_context(), +template +void ArgsortKernel(const Context& dev_ctx, + const DenseTensor& input, + int axis, + bool descending, + DenseTensor* output, + DenseTensor* indices) { + auto in_dims = input.dims(); + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + int n = in_dims[axis]; + + PADDLE_ENFORCE_LT( + n, + XPU_SORT_MAX_SIZE, + errors::InvalidArgument( + "The axis dimension of Input should less than %d, but got %d.", + XPU_SORT_MAX_SIZE, + in_dims[axis])); + + auto input_data = input.data(); + auto output_data = dev_ctx.template Alloc(output); + auto indices_data = dev_ctx.template Alloc(indices); + + int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis)); + int len_after = + phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size())); + bool int64_need_cast = + (std::is_same::value && n > (XPU_SORT_MAX_SIZE / 2)) ? true + : false; + bool index_need_cast = (n > (XPU_SORT_MAX_SIZE / 2)) ? true : false; + std::vector permute_vec{0, 2, 1}; + std::vector data_shape{len_before, n, len_after}; + + if (int64_need_cast) { + XPUArgsort()(dev_ctx.x_context(), + input_data, + output_data, + indices_data, + data_shape, + permute_vec, + descending); + } else if (index_need_cast) { + XPUArgsort()(dev_ctx.x_context(), + input_data, + output_data, + indices_data, + data_shape, + permute_vec, + descending); + } else { + XPUArgsort()(dev_ctx.x_context(), input_data, output_data, indices_data, data_shape, permute_vec, descending); - } else if (index_need_cast) { - XPUArgsort()(dev_ctx.x_context(), - input_data, - output_data, - indices_data, - data_shape, - permute_vec, - descending); - } else { - XPUArgsort()(dev_ctx.x_context(), - input_data, - output_data, - indices_data, - data_shape, - permute_vec, - descending); - } } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; +} -REGISTER_OP_XPU_KERNEL(argsort, - ops::ArgsortXPUKernel, - ops::ArgsortXPUKernel, - ops::ArgsortXPUKernel); +} // namespace phi -#endif +PD_REGISTER_KERNEL( + argsort, XPU, ALL_LAYOUT, phi::ArgsortKernel, float, int, int64_t) {}