From 0969a4eb192e61388eee315dd54469138e1ce1ea Mon Sep 17 00:00:00 2001 From: From00 Date: Thu, 3 Mar 2022 16:22:11 +0800 Subject: [PATCH] Move compare OPs to phi (#39970) * Move compare OPs to phi * Fix bug * Use BroadcastKernel and ElementwiseKernel in phi --- .../operators/controlflow/CMakeLists.txt | 2 +- .../operators/controlflow/compare_all_op.cc | 81 +-------- .../operators/controlflow/compare_all_op.cu | 92 ---------- .../operators/controlflow/compare_all_op.h | 43 ----- .../fluid/operators/controlflow/compare_op.cc | 79 +++------ .../fluid/operators/controlflow/compare_op.cu | 63 ------- .../fluid/operators/controlflow/compare_op.h | 109 ------------ .../operators/controlflow/compare_op_npu.cc | 2 +- .../operators/controlflow/compare_op_xpu.cc | 2 +- paddle/fluid/operators/matrix_rank_op.cc | 9 +- paddle/fluid/operators/matrix_rank_op.cu | 5 +- paddle/fluid/operators/matrix_rank_op.h | 1 - .../operators/metrics/accuracy_op_npu.cc | 2 +- paddle/fluid/operators/viterbi_decode_op.h | 31 ++-- paddle/phi/infermeta/binary.cc | 49 ++++++ paddle/phi/infermeta/binary.h | 9 + paddle/phi/kernels/compare_kernel.h | 47 ++++++ paddle/phi/kernels/cpu/compare_kernel.cc | 143 ++++++++++++++++ paddle/phi/kernels/funcs/compare_functors.h | 53 ++++++ paddle/phi/kernels/gpu/compare_kernel.cu | 158 ++++++++++++++++++ paddle/phi/kernels/impl/compare_kernel_impl.h | 81 +++++++++ paddle/phi/ops/compat/compare_sig.cc | 56 +++++++ 22 files changed, 654 insertions(+), 463 deletions(-) delete mode 100644 paddle/fluid/operators/controlflow/compare_all_op.cu delete mode 100644 paddle/fluid/operators/controlflow/compare_all_op.h delete mode 100644 paddle/fluid/operators/controlflow/compare_op.cu delete mode 100644 paddle/fluid/operators/controlflow/compare_op.h create mode 100644 paddle/phi/kernels/compare_kernel.h create mode 100644 paddle/phi/kernels/cpu/compare_kernel.cc create mode 100644 paddle/phi/kernels/funcs/compare_functors.h create mode 100644 paddle/phi/kernels/gpu/compare_kernel.cu create mode 100644 paddle/phi/kernels/impl/compare_kernel_impl.h create mode 100644 paddle/phi/ops/compat/compare_sig.cc diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt index a974f2ec33..70937069d9 100644 --- a/paddle/fluid/operators/controlflow/CMakeLists.txt +++ b/paddle/fluid/operators/controlflow/CMakeLists.txt @@ -19,6 +19,6 @@ else() target_link_libraries(conditional_block_infer_op conditional_block_op) endif() -file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n") +file(APPEND ${pybind_file} "USE_OP_ITSELF(less_than);\nUSE_OP_ITSELF(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP_ITSELF(logical_and);\nUSE_OP_ITSELF(logical_or);\nUSE_OP_ITSELF(logical_xor);\nUSE_OP_ITSELF(logical_not);\n") file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n") diff --git a/paddle/fluid/operators/controlflow/compare_all_op.cc b/paddle/fluid/operators/controlflow/compare_all_op.cc index ede349f737..9f229e6f15 100644 --- a/paddle/fluid/operators/controlflow/compare_all_op.cc +++ b/paddle/fluid/operators/controlflow/compare_all_op.cc @@ -12,49 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/controlflow/compare_all_op.h" -#include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { -template -class CompareReduceOpKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - using T = typename Functor::ELEM_TYPE; - using Tensor = framework::Tensor; - - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* z = context.Output("Out"); - Tensor tmp; - bool* z_data = z->mutable_data(context.GetPlace()); - - if (x->dims() != y->dims()) { - z_data[0] = false; - } else { - tmp.mutable_data(x->dims(), context.GetPlace()); - if (x->numel() == 1 && y->numel() == 1) { - bool* z_data = tmp.mutable_data(context.GetPlace()); - z_data[0] = Functor()(x->data()[0], y->data()[0]); - } else { - ElementwiseComputeEx( - context, x, y, 0, Functor(), &tmp); - } - auto ipt = framework::EigenVector::Flatten(tmp); - auto out = framework::EigenScalar::From(*z); - auto& place = - *context.template device_context() - .eigen_device(); - auto reduce_dim = Eigen::array({{0}}); - out.device(place) = ipt.all(reduce_dim); - } - } -}; - template class CompareReduceOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: @@ -81,26 +46,6 @@ template class CompareReduceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* context) const override { - OpComment comment; - PADDLE_ENFORCE_EQ(context->HasInput("X"), true, - platform::errors::InvalidArgument( - "%s operator must have input X", comment.type)); - PADDLE_ENFORCE_EQ(context->HasInput("Y"), true, - platform::errors::InvalidArgument( - "%s operator must have input Y", comment.type)); - auto dim_x = context->GetInputDim("X"); - auto dim_y = context->GetInputDim("Y"); - PADDLE_ENFORCE_GE( - dim_x.size(), dim_y.size(), - platform::errors::InvalidArgument( - "The size of dim_y should not be greater than dim_x's.")); - - context->SetOutputDim("Out", {1}); - context->ShareLoD("X", "Out"); - } }; } // namespace operators @@ -113,25 +58,13 @@ class CompareReduceOp : public framework::OperatorWithKernel { }; \ char _##op_type##Comment::type[]{#op_type}; \ char _##op_type##Comment::equation[]{_equation}; \ + DELCARE_INFER_SHAPE_FUNCTOR(op_type, op_type##_InferShapeFunctor, \ + PT_INFER_META(phi::CompareAllInferMeta)); \ REGISTER_OPERATOR( \ op_type, ::paddle::operators::CompareReduceOp<_##op_type##Comment>, \ ::paddle::operators::CompareReduceOpProtoMaker<_##op_type##Comment>, \ ::paddle::framework::EmptyGradOpMaker, \ - ::paddle::framework::EmptyGradOpMaker); + ::paddle::framework::EmptyGradOpMaker, \ + op_type##_InferShapeFunctor); -#define REGISTER_COMPARE_REDUCE_CPU_KERNEL(op_type, functor) \ - REGISTER_OP_CPU_KERNEL( \ - op_type, ::paddle::operators::CompareReduceOpKernel< \ - ::paddle::platform::CPUDeviceContext, functor>, \ - ::paddle::operators::CompareReduceOpKernel< \ - ::paddle::platform::CPUDeviceContext, functor>, \ - ::paddle::operators::CompareReduceOpKernel< \ - ::paddle::platform::CPUDeviceContext, functor>, \ - ::paddle::operators::CompareReduceOpKernel< \ - ::paddle::platform::CPUDeviceContext, functor>, \ - ::paddle::operators::CompareReduceOpKernel< \ - ::paddle::platform::CPUDeviceContext, functor>); REGISTER_COMPARE_REDUCE_OP(equal_all, "X == Y"); - -REGISTER_COMPARE_REDUCE_CPU_KERNEL(equal_all, - paddle::operators::EqualReduceFunctor); diff --git a/paddle/fluid/operators/controlflow/compare_all_op.cu b/paddle/fluid/operators/controlflow/compare_all_op.cu deleted file mode 100644 index d96dcebe51..0000000000 --- a/paddle/fluid/operators/controlflow/compare_all_op.cu +++ /dev/null @@ -1,92 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include "paddle/fluid/operators/controlflow/compare_all_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" - -namespace paddle { -namespace operators { - -template -struct BitwiseAdd { - // Bitwise add operator, returns a + b - inline T initial() { return static_cast(true); } - - __host__ __device__ __forceinline__ T operator()(const T& a, - const T& b) const { - return a & b; - } -}; - -template -class CompareReduceOpKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - using T = typename Functor::ELEM_TYPE; - using Tensor = framework::Tensor; - - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* z = context.Output("Out"); - bool* z_data = z->mutable_data(context.GetPlace()); - Tensor tmp; - - if (x->dims() != y->dims()) { - thrust::device_ptr z_dev_ptr(z_data); - thrust::fill(z_dev_ptr, z_dev_ptr + 1, false); - return; - } else { - tmp.mutable_data(x->dims(), context.GetPlace()); - const auto& cuda_ctx = - context.template device_context(); - std::vector ins = {x, y}; - std::vector outs = {&tmp}; - paddle::operators::LaunchSameDimsElementwiseCudaKernel( - cuda_ctx, ins, &outs, Functor()); - - // Reduce by 'bitwise and' operator - std::vector reduce_dims; - reduce_dims.resize(tmp.dims().size()); - for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i; - auto stream = context.cuda_device_context().stream(); - TensorReduceImpl>( - context.cuda_device_context(), tmp, z, kps::IdentityFunctor(), - reduce_dims, stream); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \ - REGISTER_OP_CUDA_KERNEL( \ - op_type, \ - ops::CompareReduceOpKernel>, \ - ops::CompareReduceOpKernel>, \ - ops::CompareReduceOpKernel>, \ - ops::CompareReduceOpKernel>, \ - ops::CompareReduceOpKernel>); - -REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, EqualReduceFunctor) -#undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL diff --git a/paddle/fluid/operators/controlflow/compare_all_op.h b/paddle/fluid/operators/controlflow/compare_all_op.h deleted file mode 100644 index 78a7b76e3f..0000000000 --- a/paddle/fluid/operators/controlflow/compare_all_op.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/platform/transform.h" - -namespace paddle { -namespace operators { - -template -struct EqualReduceFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T a, const T b) const { - if (std::is_floating_point::value) { - // This branch will be optimized while compiling if T is integer. It is - // safe to cast a and b to double. - return fabs(static_cast(a - b)) < 1e-8; - } else { - return (a == b); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/controlflow/compare_op.cc b/paddle/fluid/operators/controlflow/compare_op.cc index 657e74398b..5d9cdc6176 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cc +++ b/paddle/fluid/operators/controlflow/compare_op.cc @@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/controlflow/compare_op.h" -#include -#include -#include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -60,31 +58,6 @@ class CompareOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext* context) const override { - OpComment comment; - OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type); - OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", comment.type); - auto dim_x = context->GetInputDim("X"); - auto dim_y = context->GetInputDim("Y"); - - if (context->GetInputDim("X") == context->GetInputDim("Y")) { - context->ShareDim("X", /*->*/ "Out"); - context->ShareLoD("X", /*->*/ "Out"); - } else { - int max_dim = std::max(dim_x.size(), dim_y.size()); - int axis = std::abs(dim_x.size() - dim_y.size()); - std::vector x_dims_array(max_dim); - std::vector y_dims_array(max_dim); - std::vector out_dims_array(max_dim); - GetBroadcastDimsArrays(dim_x, dim_y, x_dims_array.data(), - y_dims_array.data(), out_dims_array.data(), - max_dim, axis); - context->SetOutputDim("Out", phi::make_ddim(out_dims_array)); - // to do - context->ShareLoD("X", /*->*/ "Out"); - } - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); @@ -116,37 +89,31 @@ class CompareOp : public framework::OperatorWithKernel { "In order to force fill output variable to gpu memory.", \ false)); -#define REGISTER_COMPARE_OP(op_type, _equation) \ - struct _##op_type##Comment { \ - static char type[]; \ - static char equation[]; \ - }; \ - char _##op_type##Comment::type[]{#op_type}; \ - char _##op_type##Comment::equation[]{_equation}; \ - REGISTER_OPERATOR( \ - op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \ - ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ - ::paddle::framework::EmptyGradOpMaker, \ - ::paddle::framework::EmptyGradOpMaker); \ +#define REGISTER_COMPARE_OP(op_type, _equation) \ + struct _##op_type##Comment { \ + static char type[]; \ + static char equation[]; \ + }; \ + char _##op_type##Comment::type[]{#op_type}; \ + char _##op_type##Comment::equation[]{_equation}; \ + DELCARE_INFER_SHAPE_FUNCTOR(op_type, op_type##_InferShapeFunctor, \ + PT_INFER_META(phi::CompareInferMeta)); \ + REGISTER_OPERATOR( \ + op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \ + ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ + ::paddle::framework::EmptyGradOpMaker, \ + ::paddle::framework::EmptyGradOpMaker, \ + op_type##_InferShapeFunctor); \ REGISTER_COMPARE_OP_VERSION(op_type); REGISTER_COMPARE_OP(less_than, "Out = X < Y"); -REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor, - paddle::operators::GreaterThanFunctor); + REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); -REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor, - paddle::operators::GreaterEqualFunctor); + REGISTER_COMPARE_OP(greater_than, "Out = X > Y"); -REGISTER_COMPARE_KERNEL(greater_than, CPU, - paddle::operators::GreaterThanFunctor, - paddle::operators::LessThanFunctor); + REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y"); -REGISTER_COMPARE_KERNEL(greater_equal, CPU, - paddle::operators::GreaterEqualFunctor, - paddle::operators::LessEqualFunctor); + REGISTER_COMPARE_OP(equal, "Out = X == Y"); -REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor, - paddle::operators::EqualFunctor); + REGISTER_COMPARE_OP(not_equal, "Out = X != Y"); -REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor, - paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/controlflow/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu deleted file mode 100644 index 4b9452d0f6..0000000000 --- a/paddle/fluid/operators/controlflow/compare_op.cu +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/controlflow/compare_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -namespace paddle { -namespace operators { - -template -class CompareOpKernel - : public framework::OpKernel { - public: - using InT = typename Functor::ELEM_TYPE; - using OutT = bool; - void Compute(const framework::ExecutionContext& ctx) const override { - auto functor = Functor(); - std::vector ins; - std::vector outs; - const auto& cuda_ctx = - ctx.template device_context(); - - int axis = PackTensorsIntoVector(ctx, &ins, &outs); - paddle::operators::LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, functor); - } -}; - -} // namespace operators -} // namespace paddle - -#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \ - REGISTER_OP_CUDA_KERNEL( \ - op_type, \ - ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, void>); - -REGISTER_CUDA_COMPARE_KERNEL(equal, EqualFunctor) -REGISTER_CUDA_COMPARE_KERNEL(not_equal, NotEqualFunctor) -REGISTER_CUDA_COMPARE_KERNEL(less_than, LessThanFunctor) -REGISTER_CUDA_COMPARE_KERNEL(less_equal, LessEqualFunctor) -REGISTER_CUDA_COMPARE_KERNEL(greater_than, GreaterThanFunctor) -REGISTER_CUDA_COMPARE_KERNEL(greater_equal, GreaterEqualFunctor) -#undef REGISTER_CUDA_COMPARE_KERNEL diff --git a/paddle/fluid/operators/controlflow/compare_op.h b/paddle/fluid/operators/controlflow/compare_op.h deleted file mode 100644 index be017a01ef..0000000000 --- a/paddle/fluid/operators/controlflow/compare_op.h +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/platform/transform.h" - -namespace paddle { -namespace operators { - -#define COMPARE_FUNCTOR(func_name, op) \ - template \ - struct func_name { \ - using ELEM_TYPE = InT; \ - HOSTDEVICE OutT operator()(const InT a, const InT b) const { \ - return static_cast(a op b); \ - } \ - }; - -COMPARE_FUNCTOR(LessThanFunctor, <) -COMPARE_FUNCTOR(LessEqualFunctor, <=) -COMPARE_FUNCTOR(GreaterThanFunctor, >) -COMPARE_FUNCTOR(GreaterEqualFunctor, >=) -#undef COMPARE_FUNCTOR - -template -struct EqualFunctor { - using ELEM_TYPE = InT; - HOSTDEVICE OutT operator()(const InT a, const InT b) const { - if (std::is_floating_point::value) { - // This branch will be optimized while compiling if T is integer. It is - // safe to cast a and b to double. - return static_cast(fabs(static_cast(a - b)) < 1e-8); - } else { - return static_cast(a == b); - } - } -}; - -template -struct NotEqualFunctor { - using ELEM_TYPE = InT; - HOSTDEVICE bool operator()(const InT a, const InT b) const { - return !EqualFunctor()(a, b); - } -}; - -template -class CompareOpKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - using T = typename Functor::ELEM_TYPE; - using Tensor = framework::Tensor; - - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* z = context.Output("Out"); - int axis = context.Attr("axis"); - - auto x_dims = x->dims(); - auto y_dims = y->dims(); - if (x_dims.size() >= y_dims.size()) { - ElementwiseComputeEx(context, x, y, axis, - Functor(), z); - } else { - ElementwiseComputeEx( - context, x, y, axis, InverseFunctor(), z); - } - } -}; - -} // namespace operators -} // namespace paddle - -#define REGISTER_COMPARE_KERNEL(op_type, dev, functor, inverse_functor) \ - REGISTER_OP_##dev##_KERNEL(op_type, \ - ::paddle::operators::CompareOpKernel< \ - ::paddle::platform::dev##DeviceContext, \ - functor, inverse_functor>, \ - ::paddle::operators::CompareOpKernel< \ - ::paddle::platform::dev##DeviceContext, \ - functor, inverse_functor>, \ - ::paddle::operators::CompareOpKernel< \ - ::paddle::platform::dev##DeviceContext, \ - functor, inverse_functor>, \ - ::paddle::operators::CompareOpKernel< \ - ::paddle::platform::dev##DeviceContext, \ - functor, inverse_functor>, \ - ::paddle::operators::CompareOpKernel< \ - ::paddle::platform::dev##DeviceContext, \ - functor, inverse_functor>, \ - ::paddle::operators::CompareOpKernel< \ - ::paddle::platform::dev##DeviceContext, \ - functor, inverse_functor>); diff --git a/paddle/fluid/operators/controlflow/compare_op_npu.cc b/paddle/fluid/operators/controlflow/compare_op_npu.cc index 7bc4ca0977..7377d7cf8d 100644 --- a/paddle/fluid/operators/controlflow/compare_op_npu.cc +++ b/paddle/fluid/operators/controlflow/compare_op_npu.cc @@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/controlflow/compare_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" diff --git a/paddle/fluid/operators/controlflow/compare_op_xpu.cc b/paddle/fluid/operators/controlflow/compare_op_xpu.cc index 698bd05161..2de8b4c9ba 100644 --- a/paddle/fluid/operators/controlflow/compare_op_xpu.cc +++ b/paddle/fluid/operators/controlflow/compare_op_xpu.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/controlflow/compare_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { diff --git a/paddle/fluid/operators/matrix_rank_op.cc b/paddle/fluid/operators/matrix_rank_op.cc index 65599259e2..1f04875c22 100644 --- a/paddle/fluid/operators/matrix_rank_op.cc +++ b/paddle/fluid/operators/matrix_rank_op.cc @@ -17,6 +17,7 @@ #include #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/svd_helper.h" +#include "paddle/phi/kernels/funcs/compare_functors.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" @@ -224,15 +225,15 @@ class MatrixRankCPUKernel : public framework::OpKernel { int axis = -1; if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) { - ElementwiseComputeEx, + ElementwiseComputeEx, platform::CPUDeviceContext, T, int>( context, &eigenvalue_tensor, &tol_tensor, axis, - GreaterThanFunctor(), &compare_result); + phi::funcs::GreaterThanFunctor(), &compare_result); } else { - ElementwiseComputeEx, + ElementwiseComputeEx, platform::CPUDeviceContext, T, int>( context, &eigenvalue_tensor, &tol_tensor, axis, - LessThanFunctor(), &compare_result); + phi::funcs::LessThanFunctor(), &compare_result); } auto dito_int = math::DeviceIndependenceTensorOperations { compare_result.mutable_data(detail::NewAxisDim(dim_out, k), context.GetPlace()); int axis = -1; - ElementwiseComputeEx, + ElementwiseComputeEx, platform::CUDADeviceContext, T, int64_t>( context, &eigenvalue_tensor, &tol_tensor, axis, - GreaterThanFunctor(), &compare_result); + phi::funcs::GreaterThanFunctor(), &compare_result); auto dito_int = math::DeviceIndependenceTensorOperations(context); diff --git a/paddle/fluid/operators/matrix_rank_op.h b/paddle/fluid/operators/matrix_rank_op.h index 80774aa916..93545fd310 100644 --- a/paddle/fluid/operators/matrix_rank_op.h +++ b/paddle/fluid/operators/matrix_rank_op.h @@ -15,7 +15,6 @@ #pragma once #include #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/controlflow/compare_op.h" #include "paddle/phi/core/ddim.h" namespace paddle { diff --git a/paddle/fluid/operators/metrics/accuracy_op_npu.cc b/paddle/fluid/operators/metrics/accuracy_op_npu.cc index 63bccc2e6e..e83278f88b 100644 --- a/paddle/fluid/operators/metrics/accuracy_op_npu.cc +++ b/paddle/fluid/operators/metrics/accuracy_op_npu.cc @@ -12,7 +12,7 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/controlflow/compare_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/metrics/accuracy_op.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" diff --git a/paddle/fluid/operators/viterbi_decode_op.h b/paddle/fluid/operators/viterbi_decode_op.h index 0974177e6c..e7fe743b96 100644 --- a/paddle/fluid/operators/viterbi_decode_op.h +++ b/paddle/fluid/operators/viterbi_decode_op.h @@ -14,12 +14,13 @@ limitations under the License. */ #include #include #include -#include "paddle/fluid/operators/controlflow/compare_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/operators/unique_op.h" +#include "paddle/phi/kernels/funcs/compare_functors.h" #include "paddle/phi/kernels/funcs/gather.h" #ifdef PADDLE_WITH_MKLML #include @@ -353,8 +354,8 @@ class ViterbiDecodeKernel : public framework::OpKernel { BinaryOperation SubInt; if (include_bos_eos_tag) { AddFloat(dev_ctx, logit0, start_trans, &alpha); - GetMask()(ctx, left_length, one, - &float_mask); + GetMask()(ctx, left_length, + one, &float_mask); MulFloat(dev_ctx, stop_trans, float_mask, &alpha_nxt); AddFloat(dev_ctx, alpha, alpha_nxt, &alpha); } else { @@ -375,8 +376,8 @@ class ViterbiDecodeKernel : public framework::OpKernel { alpha.Resize({batch_size, n_labels}); // mask = paddle.cast((left_length > 0), dtype='float32') // alpha = mask * alpha_nxt + (1 - mask) * alpha - GetMask()(ctx, left_length, zero, - &float_mask); + GetMask()( + ctx, left_length, zero, &float_mask); // alpha_nxt = mask * alpha_nxt MulFloat(dev_ctx, alpha_nxt, float_mask, &alpha_nxt); // inv_mask = 1 - mask @@ -386,8 +387,8 @@ class ViterbiDecodeKernel : public framework::OpKernel { // alpha += alpha_nxt AddFloat(dev_ctx, alpha, alpha_nxt, &alpha); if (include_bos_eos_tag) { - GetMask()(ctx, left_length, one, - &float_mask); + GetMask()(ctx, left_length, + one, &float_mask); // alpha += mask * trans_exp[:, self.stop_idx] MulFloat(dev_ctx, stop_trans, float_mask, &alpha_nxt); AddFloat(dev_ctx, alpha, alpha_nxt, &alpha); @@ -396,8 +397,8 @@ class ViterbiDecodeKernel : public framework::OpKernel { } argmax(ctx, alpha, &last_ids, scores, 1); left_length.Resize({batch_size}); - GetMask()(ctx, left_length, - zero, &int_mask); + GetMask()( + ctx, left_length, zero, &int_mask); // last_ids_update = last_ids * tag_mask int last_ids_index = 1; int actual_len = (std::min)(seq_len, static_cast(max_seq_len)); @@ -416,17 +417,17 @@ class ViterbiDecodeKernel : public framework::OpKernel { batch_path[actual_len - last_ids_index]; hist->Resize({batch_size * n_labels}); gather(dev_ctx, *hist, gather_idx, &last_ids_update); - GetMask()(ctx, left_length, - zero, &int_mask); + GetMask()( + ctx, left_length, zero, &int_mask); MulInt(dev_ctx, last_ids_update, int_mask, &last_ids_update); - GetMask()(ctx, left_length, zero, - &zero_len_mask); + GetMask()( + ctx, left_length, zero, &zero_len_mask); MulInt(dev_ctx, last_ids, zero_len_mask, &last_ids_tmp); SubInt(dev_ctx, one, zero_len_mask, &zero_len_mask); MulInt(dev_ctx, last_ids_update, zero_len_mask, &last_ids_update); AddInt(dev_ctx, last_ids_update, last_ids_tmp, &last_ids_update); - GetMask()(ctx, left_length, zero, - &int_mask); + GetMask()( + ctx, left_length, zero, &int_mask); MulInt(dev_ctx, last_ids, int_mask, &last_ids); AddInt(dev_ctx, last_ids_update, last_ids, &last_ids); } diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 7682f6b3d4..1f6f0b211b 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -13,11 +13,60 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/infermeta/binary.h" + +#include +#include +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/kernels/funcs/common_shape.h" namespace phi { +void CompareInferMeta(const MetaTensor& x, + const MetaTensor& y, + int axis, + MetaTensor* out) { + auto dim_x = x.dims(); + auto dim_y = y.dims(); + + if (dim_x == dim_y) { + out->share_meta(x); + } else { + int max_dim = std::max(dim_x.size(), dim_y.size()); + int axis = std::abs(dim_x.size() - dim_y.size()); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + funcs::GetBroadcastDimsArrays(dim_x, + dim_y, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + + out->set_dims(make_ddim(out_dims_array)); + out->share_lod(x); + } + + out->set_dtype(DataType::BOOL); +} + +void CompareAllInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out) { + auto dim_x = x.dims(); + auto dim_y = y.dims(); + PADDLE_ENFORCE_GE( + dim_x.size(), + dim_y.size(), + errors::InvalidArgument( + "The size of dim_y should not be greater than dim_x's.")); + out->share_lod(x); + out->set_dims(make_ddim({1})); + out->set_dtype(DataType::BOOL); +} + void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { auto x_dims = x.dims(); auto x_rank = static_cast(x_dims.size()); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 5906e06b29..47745f8ce1 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -29,6 +29,15 @@ namespace phi { // Because functions in this file not only can infer shape, but also need // infer lod or other useful data. +void CompareInferMeta(const MetaTensor& x, + const MetaTensor& y, + int axis, + MetaTensor* out); + +void CompareAllInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out); + void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); void MatmulInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/compare_kernel.h b/paddle/phi/kernels/compare_kernel.h new file mode 100644 index 0000000000..5b6b8cd868 --- /dev/null +++ b/paddle/phi/kernels/compare_kernel.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +#define DECALRE_COMPARE_KERNEL(compare_kernel) \ + template \ + void compare_kernel(const Context& ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + int axis, \ + DenseTensor* out); + +DECALRE_COMPARE_KERNEL(LessThanKernel) +DECALRE_COMPARE_KERNEL(LessEqualKernel) +DECALRE_COMPARE_KERNEL(GreaterThanKernel) +DECALRE_COMPARE_KERNEL(GreaterEqualKernel) +DECALRE_COMPARE_KERNEL(EqualKernel) +DECALRE_COMPARE_KERNEL(NotEqualKernel) +#undef DECALRE_COMPARE_KERNEL + +#define DECALRE_COMPARE_ALL_KERNEL(compare_all_kernel) \ + template \ + void compare_all_kernel(const Context& ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out); + +DECALRE_COMPARE_ALL_KERNEL(EqualAll) +#undef DECALRE_COMPARE_KERNEL + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc new file mode 100644 index 0000000000..9006325a52 --- /dev/null +++ b/paddle/phi/kernels/cpu/compare_kernel.cc @@ -0,0 +1,143 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/compare_kernel.h" +#include "paddle/phi/kernels/impl/compare_kernel_impl.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" + +namespace phi { + +template +inline void CompareKernelImpl(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + ctx.template Alloc(out); + if (x.dims().size() >= y.dims().size()) { + funcs::ElementwiseCompute( + ctx, x, y, axis, Functor(), out); + } else { + funcs::ElementwiseCompute( + ctx, x, y, axis, InverseFunctor(), out); + } +} + +template +inline void CompareAllKernelImpl(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + bool* out_data = ctx.template Alloc(out); + + if (x.dims() != y.dims()) { + out_data[0] = false; + } else { + DenseTensor tmp; + tmp.Resize(x.dims()); + ctx.template Alloc(&tmp); + + if (x.numel() == 1 && y.numel() == 1) { + bool* tmp_data = tmp.data(); + tmp_data[0] = Functor()(x.data()[0], y.data()[0]); + } else { + funcs::ElementwiseCompute( + ctx, x, y, 0, Functor(), &tmp); + } + auto tmp_flat = EigenVector::Flatten(tmp); + auto out_es = EigenScalar::From(*out); + auto& place = *ctx.eigen_device(); + auto reduce_dim = Eigen::array({{0}}); + out_es.device(place) = tmp_flat.all(reduce_dim); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(less_than, + CPU, + ALL_LAYOUT, + phi::LessThanKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} +PD_REGISTER_KERNEL(less_equal, + CPU, + ALL_LAYOUT, + phi::LessEqualKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} +PD_REGISTER_KERNEL(greater_than, + CPU, + ALL_LAYOUT, + phi::GreaterThanKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} +PD_REGISTER_KERNEL(greater_equal, + CPU, + ALL_LAYOUT, + phi::GreaterEqualKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} +PD_REGISTER_KERNEL(equal, + CPU, + ALL_LAYOUT, + phi::EqualKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} +PD_REGISTER_KERNEL(not_equal, + CPU, + ALL_LAYOUT, + phi::NotEqualKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} + +PD_REGISTER_KERNEL(equal_all, + CPU, + ALL_LAYOUT, + phi::EqualAllKernel, + bool, + int, + int64_t, + float, + double) {} diff --git a/paddle/phi/kernels/funcs/compare_functors.h b/paddle/phi/kernels/funcs/compare_functors.h new file mode 100644 index 0000000000..569fed7b7f --- /dev/null +++ b/paddle/phi/kernels/funcs/compare_functors.h @@ -0,0 +1,53 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace phi { +namespace funcs { + +#define COMPARE_FUNCTOR(func_name, op) \ + template \ + struct func_name { \ + HOSTDEVICE OutT operator()(const InT a, const InT b) const { \ + return static_cast(a op b); \ + } \ + }; + +COMPARE_FUNCTOR(LessThanFunctor, <) +COMPARE_FUNCTOR(LessEqualFunctor, <=) +COMPARE_FUNCTOR(GreaterThanFunctor, >) +COMPARE_FUNCTOR(GreaterEqualFunctor, >=) +#undef COMPARE_FUNCTOR + +template +struct EqualFunctor { + HOSTDEVICE OutT operator()(const InT a, const InT b) const { + if (std::is_floating_point::value) { + return static_cast(fabs(static_cast(a - b)) < 1e-8); + } else { + return static_cast(a == b); + } + } +}; + +template +struct NotEqualFunctor { + HOSTDEVICE bool operator()(const InT a, const InT b) const { + return !EqualFunctor()(a, b); + } +}; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/compare_kernel.cu b/paddle/phi/kernels/gpu/compare_kernel.cu new file mode 100644 index 0000000000..272448504a --- /dev/null +++ b/paddle/phi/kernels/gpu/compare_kernel.cu @@ -0,0 +1,158 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/compare_kernel.h" +#include "paddle/phi/kernels/impl/compare_kernel_impl.h" + +#include +#include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/gpu/reduce.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" + +namespace phi { + +template +struct BitwiseAdd { + // Bitwise add operator, returns a + b + inline T initial() { return static_cast(true); } + + __host__ __device__ __forceinline__ T operator()(const T& a, + const T& b) const { + return a & b; + } +}; + +template +inline void CompareKernelImpl(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + ctx.template Alloc(out); + std::vector ins{&x, &y}; + std::vector outs{out}; + funcs::BroadcastKernel( + ctx, ins, &outs, axis, Functor()); +} + +template +inline void CompareAllKernelImpl(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + bool* out_data = ctx.template Alloc(out); + + if (x.dims() != y.dims()) { + thrust::device_ptr out_dev_ptr(out_data); + thrust::fill(out_dev_ptr, out_dev_ptr + 1, false); + return; + } + + DenseTensor tmp; + tmp.Resize(x.dims()); + ctx.template Alloc(&tmp); + + std::vector ins{&x, &y}; + std::vector outs{&tmp}; + funcs::ElementwiseKernel(ctx, ins, &outs, Functor()); + + // Reduce by 'bitwise and' operator + std::vector reduce_dims; + reduce_dims.resize(tmp.dims().size()); + for (int i = 0; i < reduce_dims.size(); ++i) { + reduce_dims[i] = i; + } + kernels::TensorReduceImpl>( + ctx, tmp, out, kps::IdentityFunctor(), reduce_dims, ctx.stream()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(less_than, + GPU, + ALL_LAYOUT, + phi::LessThanKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} +PD_REGISTER_KERNEL(less_equal, + GPU, + ALL_LAYOUT, + phi::LessEqualKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} +PD_REGISTER_KERNEL(greater_than, + GPU, + ALL_LAYOUT, + phi::GreaterThanKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} +PD_REGISTER_KERNEL(greater_equal, + GPU, + ALL_LAYOUT, + phi::GreaterEqualKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} +PD_REGISTER_KERNEL(equal, + GPU, + ALL_LAYOUT, + phi::EqualKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} +PD_REGISTER_KERNEL(not_equal, + GPU, + ALL_LAYOUT, + phi::NotEqualKernel, + bool, + int16_t, + int, + int64_t, + float, + double) {} + +PD_REGISTER_KERNEL(equal_all, + GPU, + ALL_LAYOUT, + phi::EqualAllKernel, + bool, + int, + int64_t, + float, + double) {} diff --git a/paddle/phi/kernels/impl/compare_kernel_impl.h b/paddle/phi/kernels/impl/compare_kernel_impl.h new file mode 100644 index 0000000000..4390c1f8e6 --- /dev/null +++ b/paddle/phi/kernels/impl/compare_kernel_impl.h @@ -0,0 +1,81 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/compare_kernel.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/compare_functors.h" + +namespace phi { + +template +inline void CompareKernelImpl(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + +template +inline void CompareAllKernelImpl(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +#define DEFINE_COMPARE_KERNEL(compare_kernel, functor, inverse_functor) \ + template \ + void compare_kernel(const Context& ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + int axis, \ + DenseTensor* out) { \ + CompareKernelImpl, inverse_functor>( \ + ctx, x, y, axis, out); \ + } + +DEFINE_COMPARE_KERNEL(LessThanKernel, + funcs::LessThanFunctor, + funcs::GreaterThanFunctor) +DEFINE_COMPARE_KERNEL(LessEqualKernel, + funcs::LessEqualFunctor, + funcs::GreaterEqualFunctor) +DEFINE_COMPARE_KERNEL(GreaterThanKernel, + funcs::GreaterThanFunctor, + funcs::LessThanFunctor) +DEFINE_COMPARE_KERNEL(GreaterEqualKernel, + funcs::GreaterEqualFunctor, + funcs::LessEqualFunctor) +DEFINE_COMPARE_KERNEL(EqualKernel, funcs::EqualFunctor, funcs::EqualFunctor) +DEFINE_COMPARE_KERNEL(NotEqualKernel, + funcs::NotEqualFunctor, + funcs::NotEqualFunctor) +#undef DEFINE_COMPARE_KERNEL + +#define DEFINE_COMPARE_ALL_KERNEL(compare_all_kernel, functor) \ + template \ + void compare_all_kernel(const Context& ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + CompareAllKernelImpl>(ctx, x, y, out); \ + } + +DEFINE_COMPARE_ALL_KERNEL(EqualAllKernel, funcs::EqualFunctor) +#undef DEFINE_COMPARE_ALL_KERNEL + +} // namespace phi diff --git a/paddle/phi/ops/compat/compare_sig.cc b/paddle/phi/ops/compat/compare_sig.cc new file mode 100644 index 0000000000..964c7be3db --- /dev/null +++ b/paddle/phi/ops/compat/compare_sig.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LessThanArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("less_than", {"X", "Y"}, {"axis"}, {"Out"}); +} + +KernelSignature LessEqualArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("less_equal", {"X", "Y"}, {"axis"}, {"Out"}); +} + +KernelSignature GreaterThanArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("greater_than", {"X", "Y"}, {"axis"}, {"Out"}); +} + +KernelSignature GreaterEqualArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("greater_equal", {"X", "Y"}, {"axis"}, {"Out"}); +} + +KernelSignature EqualArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("equal", {"X", "Y"}, {"axis"}, {"Out"}); +} + +KernelSignature NotEqualArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("not_equal", {"X", "Y"}, {"axis"}, {"Out"}); +} + +KernelSignature EqualAllArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("equal_all", {"X", "Y"}, {}, {"Out"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(less_than, phi::LessThanArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(less_equal, phi::LessEqualArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(greater_than, phi::GreaterThanArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(greater_equal, phi::GreaterEqualArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(equal, phi::EqualArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(not_equal, phi::NotEqualArgumentMapping); + +PD_REGISTER_ARG_MAPPING_FN(equal_all, phi::EqualAllArgumentMapping); -- GitLab