// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <math.h>

#include <type_traits>

// NOTE(review): HOSTDEVICE is not defined in this header; it is expected to
// be provided by an earlier include (presumably Paddle's hostdevice header).

namespace phi {
namespace funcs {

// Defines a binary comparison functor `func_name<InT, OutT = bool>` whose
// call operator returns `static_cast<OutT>(a op b)`.
#define COMPARE_FUNCTOR(func_name, op)                           \
  template <typename InT, typename OutT = bool>                  \
  struct func_name {                                             \
    HOSTDEVICE OutT operator()(const InT a, const InT b) const { \
      return static_cast<OutT>(a op b);                          \
    }                                                            \
  };

COMPARE_FUNCTOR(LessThanFunctor, <)
COMPARE_FUNCTOR(LessEqualFunctor, <=)
COMPARE_FUNCTOR(GreaterThanFunctor, >)
COMPARE_FUNCTOR(GreaterEqualFunctor, >=)
#undef COMPARE_FUNCTOR

/// Equality functor, result converted to OutT.
///
/// For non-floating-point InT this is exact `a == b`. For floating-point InT:
///   - if either operand is +/-inf, the result is exact `a == b`
///     (inf == inf, inf != -inf, inf != any finite value);
///   - otherwise, if either operand is NaN, the result is false;
///   - otherwise the operands are equal when |a - b| < 1e-8. This is an
///     absolute tolerance, not a relative one.
template <typename InT, typename OutT = bool>
struct EqualFunctor {
  HOSTDEVICE OutT operator()(const InT a, const InT b) const {
    // Tag dispatch instead of a runtime `if`: the floating-point math below is
    // only instantiated for floating-point InT, so any InT that provides
    // operator== can use the exact path.
    return static_cast<OutT>(
        Compare(a, b, typename std::is_floating_point<InT>::type()));
  }

 private:
  // Floating-point path (inf / NaN aware, absolute tolerance).
  static HOSTDEVICE bool Compare(const InT a, const InT b, std::true_type) {
    // Handle inf before subtracting: inf - inf is NaN, which would otherwise
    // make inf compare unequal to itself.
    if (isinf(a) || isinf(b)) return a == b;
    if (isnan(a) || isnan(b)) return false;
    return fabs(a - b) < 1e-8;
  }

  // Exact path for every non-floating-point InT.
  static HOSTDEVICE bool Compare(const InT a, const InT b, std::false_type) {
    return a == b;
  }
};

/// Negation of EqualFunctor, with the same inf/NaN/tolerance rules (so
/// NaN != NaN is true). The result is converted to OutT, consistent with the
/// other functors in this file.
template <typename InT, typename OutT = bool>
struct NotEqualFunctor {
  HOSTDEVICE OutT operator()(const InT a, const InT b) const {
    return static_cast<OutT>(!EqualFunctor<InT, bool>()(a, b));
  }
};

}  // namespace funcs
}  // namespace phi