// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/impl/compare_kernel_impl.h" #ifdef PADDLE_WITH_XPU_KP #include "paddle/phi/backends/xpu/xpu_context.h" #else #include #include #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/compare_kernel.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/gpu/reduce.h" #include "paddle/phi/kernels/primitive/functor_primitives.h" #endif namespace phi { template struct BitwiseAdd { // Bitwise add operator, returns a + b inline T initial() { return static_cast(true); } __host__ __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a & b; } }; template inline void CompareKernelImpl(const Context& ctx, const DenseTensor& x, const DenseTensor& y, int axis, DenseTensor* out) { ctx.template Alloc(out); std::vector ins{&x, &y}; std::vector outs{out}; funcs::BroadcastKernel(ctx, ins, &outs, Functor(), axis); } #ifndef PADDLE_WITH_XPU_KP template inline void CompareAllKernelImpl(const Context& ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* out) { bool* out_data = ctx.template Alloc(out); if (x.dims() != y.dims()) { thrust::device_ptr out_dev_ptr(out_data); thrust::fill(out_dev_ptr, out_dev_ptr + 1, false); return; } DenseTensor tmp; tmp.Resize(x.dims()); ctx.template Alloc(&tmp); std::vector ins{&x, &y}; std::vector outs{&tmp}; funcs::ElementwiseKernel(ctx, ins, &outs, Functor()); // Reduce by 'bitwise and' operator std::vector reduce_dims; reduce_dims.resize(tmp.dims().size()); for (int i = 0; i < reduce_dims.size(); ++i) { reduce_dims[i] = i; } funcs::ReduceKernel>( ctx, tmp, out, kps::IdentityFunctor(), reduce_dims); } #endif } // namespace phi #ifdef PADDLE_WITH_XPU_KP PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, phi::GreaterThanKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL( greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL( less_than_raw, KPS, ALL_LAYOUT, phi::LessThanRawKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL( less_equal_raw, KPS, ALL_LAYOUT, phi::LessEqualRawKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL( greater_than_raw, KPS, ALL_LAYOUT, phi::GreaterThanRawKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL( greater_equal_raw, KPS, ALL_LAYOUT, phi::GreaterEqualRawKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL(equal_raw, KPS, ALL_LAYOUT, phi::EqualRawKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } PD_REGISTER_KERNEL( not_equal_raw, KPS, ALL_LAYOUT, phi::NotEqualRawKernel, int) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } #else PD_REGISTER_KERNEL(equal_all, KPS, ALL_LAYOUT, phi::EqualAllKernel, bool, int, int64_t, float, double) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } #define PD_REGISTER_COMPARE_KERNEL(name, func) \ PD_REGISTER_KERNEL(name, \ KPS, \ ALL_LAYOUT, \ phi::func##Kernel, \ bool, \ int16_t, \ int, \ int64_t, \ float, \ double, \ phi::dtype::float16, \ phi::dtype::bfloat16) { \ kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ } \ PD_REGISTER_KERNEL(name##_raw, \ KPS, \ ALL_LAYOUT, \ phi::func##RawKernel, \ bool, \ int16_t, \ int, \ int64_t, \ float, \ double, \ phi::dtype::float16, \ phi::dtype::bfloat16) { \ kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ } PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual) PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan) PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual) PD_REGISTER_COMPARE_KERNEL(equal, Equal) PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual) #endif