diff --git a/paddle/fluid/operators/controlflow/compare_op_xpu.cc b/paddle/fluid/operators/controlflow/compare_op_xpu.cc
deleted file mode 100644
index a62c64fd8d550ea90055b7893d1eb1cfe4ff5790..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/controlflow/compare_op_xpu.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-#ifdef PADDLE_WITH_XPU
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T, typename XPUType>
-void XPUCompare(const framework::ExecutionContext& ctx,
-                std::function<int(xpu::Context*,
-                                  const XPUType*,
-                                  const XPUType*,
-                                  bool*,
-                                  const std::vector<int>&,
-                                  const std::vector<int>&)> func) {
-  auto* x = ctx.Input<framework::Tensor>("X");
-  auto* y = ctx.Input<framework::Tensor>("Y");
-  auto* z = ctx.Output<framework::Tensor>("Out");
-
-  auto x_shape = phi::vectorize<int>(x->dims());
-  auto y_shape = phi::vectorize<int>(y->dims());
-
-  auto x_data = reinterpret_cast<const XPUType*>(x->data<T>());
-  auto y_data = reinterpret_cast<const XPUType*>(y->data<T>());
-  auto z_data = z->mutable_data<bool>(ctx.GetPlace());
-
-  auto& dev_ctx =
-      ctx.template device_context<platform::XPUDeviceContext>();
-
-  int ret = func(dev_ctx.x_context(), x_data, y_data, z_data, x_shape, y_shape);
-  PADDLE_ENFORCE_EQ(
-      ret,
-      xpu::SUCCESS,
-      platform::errors::External(
-          "XPU kernel compare op occur error[%d %s] in XPUCompare.",
-          ret,
-          XPUAPIErrorMsg[ret]));
-}
-
-template <typename DeviceContext, typename T>
-class EqualXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    XPUCompare<T, XPUType>(ctx, xpu::broadcast_equal<XPUType>);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class NotEqualXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    XPUCompare<T, XPUType>(ctx, xpu::broadcast_not_equal<XPUType>);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class LessThanXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    XPUCompare<T, XPUType>(ctx, xpu::broadcast_less_than<XPUType>);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class LessEqualXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    XPUCompare<T, XPUType>(ctx, xpu::broadcast_less_equal<XPUType>);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class GreaterThanXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    XPUCompare<T, XPUType>(ctx, xpu::broadcast_greater_than<XPUType>);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class GreaterEqualXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    XPUCompare<T, XPUType>(ctx, xpu::broadcast_greater_equal<XPUType>);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_XPU_KERNEL(equal,
-                       ops::EqualXPUKernel<plat::XPUDeviceContext, float>,
-                       ops::EqualXPUKernel<plat::XPUDeviceContext, int>,
-                       ops::EqualXPUKernel<plat::XPUDeviceContext, int64_t>);
-
-REGISTER_OP_XPU_KERNEL(not_equal,
-                       ops::NotEqualXPUKernel<plat::XPUDeviceContext, float>,
-                       ops::NotEqualXPUKernel<plat::XPUDeviceContext, int>,
-                       ops::NotEqualXPUKernel<plat::XPUDeviceContext, int64_t>);
-
-REGISTER_OP_XPU_KERNEL(less_than,
-                       ops::LessThanXPUKernel<plat::XPUDeviceContext, float>,
-                       ops::LessThanXPUKernel<plat::XPUDeviceContext, int>,
-                       ops::LessThanXPUKernel<plat::XPUDeviceContext, int64_t>);
-
-REGISTER_OP_XPU_KERNEL(
-    less_equal,
-    ops::LessEqualXPUKernel<plat::XPUDeviceContext, float>,
-    ops::LessEqualXPUKernel<plat::XPUDeviceContext, int>,
-    ops::LessEqualXPUKernel<plat::XPUDeviceContext, int64_t>);
-
-REGISTER_OP_XPU_KERNEL(
-    greater_than,
-    ops::GreaterThanXPUKernel<plat::XPUDeviceContext, float>,
-    ops::GreaterThanXPUKernel<plat::XPUDeviceContext, int>,
-    ops::GreaterThanXPUKernel<plat::XPUDeviceContext, int64_t>);
-
-REGISTER_OP_XPU_KERNEL(
-    greater_equal,
-    ops::GreaterEqualXPUKernel<plat::XPUDeviceContext, float>,
-    ops::GreaterEqualXPUKernel<plat::XPUDeviceContext, int>,
-    ops::GreaterEqualXPUKernel<plat::XPUDeviceContext, int64_t>);
-
-#endif
diff --git a/paddle/phi/kernels/xpu/compare_kernel.cc b/paddle/phi/kernels/xpu/compare_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..32866e7aa701e7b124a1f22043078fe8b9811233
--- /dev/null
+++ b/paddle/phi/kernels/xpu/compare_kernel.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/compare_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/backends/xpu/xpu_header.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename XPUType, typename Context>
+void XPUCompareKernelImpl(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& y,
+                          DenseTensor* out,
+                          std::function<int(xpu::Context*,
+                                            const XPUType*,
+                                            const XPUType*,
+                                            bool*,
+                                            const std::vector<int>&,
+                                            const std::vector<int>&)> func) {
+  auto x_shape = vectorize<int>(x.dims());
+  auto y_shape = vectorize<int>(y.dims());
+
+  auto x_data = reinterpret_cast<const XPUType*>(x.data<T>());
+  auto y_data = reinterpret_cast<const XPUType*>(y.data<T>());
+  auto* out_data = dev_ctx.template Alloc<bool>(out);
+
+  int ret =
+      func(dev_ctx.x_context(), x_data, y_data, out_data, x_shape, y_shape);
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "compare op");
+}
+
+#define DEFINE_XPU_COMPARE_KERNEL(compare_kernel, functor)                  \
+  template <typename T, typename Context>                                   \
+  void compare_kernel(const Context& dev_ctx,                               \
+                      const DenseTensor& x,                                 \
+                      const DenseTensor& y,                                 \
+                      int axis,                                             \
+                      DenseTensor* out) {                                   \
+    using XPUType = typename XPUTypeTrait<T>::Type;                         \
+    XPUCompareKernelImpl<T, XPUType, Context>(dev_ctx, x, y, out, functor); \
+  }
+
+DEFINE_XPU_COMPARE_KERNEL(EqualKernel, xpu::broadcast_equal<XPUType>)
+DEFINE_XPU_COMPARE_KERNEL(NotEqualKernel, xpu::broadcast_not_equal<XPUType>)
+DEFINE_XPU_COMPARE_KERNEL(LessThanKernel, xpu::broadcast_less_than<XPUType>)
+DEFINE_XPU_COMPARE_KERNEL(LessEqualKernel, xpu::broadcast_less_equal<XPUType>)
+DEFINE_XPU_COMPARE_KERNEL(GreaterThanKernel,
+                          xpu::broadcast_greater_than<XPUType>)
+DEFINE_XPU_COMPARE_KERNEL(GreaterEqualKernel,
+                          xpu::broadcast_greater_equal<XPUType>)
+#undef DEFINE_XPU_COMPARE_KERNEL
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    equal, XPU, ALL_LAYOUT, phi::EqualKernel, float, int, int64_t) {}
+PD_REGISTER_KERNEL(
+    not_equal, XPU, ALL_LAYOUT, phi::NotEqualKernel, float, int, int64_t) {}
PD_REGISTER_KERNEL(
+    less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, float, int, int64_t) {}
+PD_REGISTER_KERNEL(
+    less_equal, XPU, ALL_LAYOUT, phi::LessEqualKernel, float, int, int64_t) {}
+PD_REGISTER_KERNEL(greater_than,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::GreaterThanKernel,
+                   float,
+                   int,
+                   int64_t) {}
+PD_REGISTER_KERNEL(greater_equal,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::GreaterEqualKernel,
+                   float,
+                   int,
+                   int64_t) {}
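
Note: the new phi kernel keeps the old file's design of routing every comparison through one shared implementation and injecting the per-op XDNN primitive as a std::function; DEFINE_XPU_COMPARE_KERNEL then stamps out EqualKernel, NotEqualKernel, etc., each binding its xpu::broadcast_* call. As a rough, host-only sketch of that dispatch pattern (hypothetical names, plain pointers instead of DenseTensor/XPU buffers, and no broadcasting, which the real xpu::broadcast_* kernels perform on device):

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Shared implementation: the comparison primitive is injected as a
// std::function, analogous to the `func` parameter of XPUCompareKernelImpl.
template <typename T>
int CompareImpl(const T* x, const T* y, bool* out, std::int64_t n,
                std::function<int(const T*, const T*, bool*, std::int64_t)> func) {
  return func(x, y, out, n);  // 0 means success, echoing the xpu::SUCCESS check
}

// Stand-in for one xpu::broadcast_* primitive (same-shape inputs only here).
template <typename T>
int LessThan(const T* x, const T* y, bool* out, std::int64_t n) {
  for (std::int64_t i = 0; i < n; ++i) out[i] = x[i] < y[i];
  return 0;
}

int main() {
  std::vector<float> x{1.f, 2.f, 3.f};
  std::vector<float> y{2.f, 2.f, 2.f};
  bool out[3] = {};

  int ret = CompareImpl<float>(x.data(), y.data(), out,
                               static_cast<std::int64_t>(x.size()),
                               LessThan<float>);
  if (ret != 0) return ret;
  for (bool v : out) std::cout << v << ' ';  // prints: 1 0 0
  std::cout << '\n';
  return 0;
}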