提交 6f78cb99 编写于 作者: Q qiaolongfei

add not_equal

上级 23ba79b1
......@@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_OP(not_equal, "Out = X != Y");
REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);
......@@ -17,3 +17,4 @@ limitations under the License. */
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
......@@ -48,6 +48,14 @@ struct EqualFunctor {
}
};
template <typename T>
struct NotEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
return !EqualFunctor<T>()(a, b);
}
};
template <typename DeviceContext, typename Functor>
class CompareOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
......
......@@ -154,8 +154,9 @@ def monkey_patch_variable():
("__rpow__", "elementwise_pow", True),
# for logical compare
("__eq__", "equal", False),
("__ne__", "not_equal", False),
("__lt__", "less_than", False),
("__le__", "less_equal", False), ):
("__le__", "less_equal", False)):
setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse))
......
......@@ -53,6 +53,7 @@ class TestPythonOperatorOverride(unittest.TestCase):
lambda _a, _b: _a > _b,
lambda _a, _b: _a <= _b,
lambda _a, _b: _a >= _b,
lambda _a, _b: _a != _b,
]
# places to check
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册