// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "lite/kernels/arm/compare_compute.h" #include #include "lite/api/paddle_place.h" #include "lite/backends/arm/math/funcs.h" #include "lite/core/op_registry.h" #include "lite/core/type_system.h" namespace paddle { namespace lite { namespace kernels { namespace arm { #define COMPARE_FUNCTOR(name, op) \ template \ struct _##name##Functor { \ inline bool operator()(const T &a, const T &b) const { return a op b; } \ }; COMPARE_FUNCTOR(Equal, ==); COMPARE_FUNCTOR(NotEqual, !=); COMPARE_FUNCTOR(LessThan, <); COMPARE_FUNCTOR(LessEqual, <=); COMPARE_FUNCTOR(GreaterThan, >); COMPARE_FUNCTOR(GreaterEqual, >=); template <> struct _EqualFunctor { inline bool operator()(const float &a, const float &b) const { // It is safe to cast a and b to double. return fabs(static_cast(a - b)) < 1e-8; } }; template <> struct _NotEqualFunctor { inline bool operator()(const float &a, const float &b) const { return !_EqualFunctor()(a, b); } }; inline void get_mid_dims(const lite::DDim &x_dims, const lite::DDim &y_dims, const int axis, int *pre, int *n, int *post) { *pre = 1; *n = 1; *post = 1; for (int i = 0; i < axis; ++i) { (*pre) *= x_dims[i]; } for (int i = 0; i < y_dims.size(); ++i) { (*n) *= y_dims[i]; } for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { (*post) *= x_dims[i]; } } template