提交 6f78cb99 编写于 作者: Q qiaolongfei

add not_equal

上级 23ba79b1
...@@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); ...@@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_OP(not_equal, "Out = X != Y");
REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);
...@@ -17,3 +17,4 @@ limitations under the License. */ ...@@ -17,3 +17,4 @@ limitations under the License. */
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
...@@ -48,6 +48,14 @@ struct EqualFunctor { ...@@ -48,6 +48,14 @@ struct EqualFunctor {
} }
}; };
template <typename T>
struct NotEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
return !EqualFunctor<T>()(a, b);
}
};
template <typename DeviceContext, typename Functor> template <typename DeviceContext, typename Functor>
class CompareOpKernel class CompareOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> { : public framework::OpKernel<typename Functor::ELEM_TYPE> {
......
...@@ -154,8 +154,9 @@ def monkey_patch_variable(): ...@@ -154,8 +154,9 @@ def monkey_patch_variable():
("__rpow__", "elementwise_pow", True), ("__rpow__", "elementwise_pow", True),
# for logical compare # for logical compare
("__eq__", "equal", False), ("__eq__", "equal", False),
("__ne__", "not_equal", False),
("__lt__", "less_than", False), ("__lt__", "less_than", False),
("__le__", "less_equal", False), ): ("__le__", "less_equal", False)):
setattr(Variable, method_name, setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse)) _elemwise_method_creator_(method_name, op_type, reverse))
......
...@@ -53,6 +53,7 @@ class TestPythonOperatorOverride(unittest.TestCase): ...@@ -53,6 +53,7 @@ class TestPythonOperatorOverride(unittest.TestCase):
lambda _a, _b: _a > _b, lambda _a, _b: _a > _b,
lambda _a, _b: _a <= _b, lambda _a, _b: _a <= _b,
lambda _a, _b: _a >= _b, lambda _a, _b: _a >= _b,
lambda _a, _b: _a != _b,
] ]
# places to check # places to check
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册