未验证 提交 c3a61349 编写于 作者: A Abhinav Arora 提交者: GitHub

Adding greater than and less than equal ops to compare op (#5609)

* Adding greater than and less than equal ops to compare op
* Changing the name of the less_than_equal and greater_than_equal op
* Also changing the name of the functors
上级 562599ec
...@@ -94,5 +94,13 @@ class CompareOp : public framework::OperatorWithKernel { ...@@ -94,5 +94,13 @@ class CompareOp : public framework::OperatorWithKernel {
REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(greater_than, "Out = X > Y");
REGISTER_LOGICAL_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y");
REGISTER_LOGICAL_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
...@@ -15,4 +15,9 @@ ...@@ -15,4 +15,9 @@
#include "paddle/operators/compare_op.h" #include "paddle/operators/compare_op.h"
REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, GPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(greater_than, GPU,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_KERNEL(greater_equal, GPU,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor);
...@@ -27,6 +27,24 @@ struct LessThanFunctor { ...@@ -27,6 +27,24 @@ struct LessThanFunctor {
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; } HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; }
}; };
template <typename T>
struct LessEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
};
template <typename T>
struct GreaterThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
};
template <typename T>
struct GreaterEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
};
template <typename T> template <typename T>
struct EqualFunctor { struct EqualFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
......
...@@ -23,6 +23,9 @@ def create_test_class(op_type, typename, callback): ...@@ -23,6 +23,9 @@ def create_test_class(op_type, typename, callback):
for _type_name in {'float32', 'float64', 'int32', 'int64'}: for _type_name in {'float32', 'float64', 'int32', 'int64'}:
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册