提交 28d07e3c 编写于 作者: Q qiaolongfei

add python part of compare op

上级 d4e3495c
...@@ -100,11 +100,12 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y"); ...@@ -100,11 +100,12 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_OP(larger_than, "Out = X > Y"); REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
REGISTER_COMPARE_KERNEL(larger_than, CPU, paddle::operators::LargerThanFunctor); REGISTER_COMPARE_KERNEL(greater_than, CPU,
REGISTER_COMPARE_OP(larger_equal, "Out = X >= Y"); paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_KERNEL(larger_equal, CPU, REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
paddle::operators::LargerEqualFunctor); REGISTER_COMPARE_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_OP(equal, "Out = X == Y"); REGISTER_COMPARE_OP(equal, "Out = X == Y");
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
REGISTER_COMPARE_OP(not_equal, "Out = X != Y"); REGISTER_COMPARE_OP(not_equal, "Out = X != Y");
......
...@@ -16,9 +16,9 @@ limitations under the License. */ ...@@ -16,9 +16,9 @@ limitations under the License. */
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_KERNEL(larger_than, CUDA, REGISTER_COMPARE_KERNEL(greater_than, CUDA,
paddle::operators::LargerThanFunctor); paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_KERNEL(larger_equal, CUDA, REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
paddle::operators::LargerEqualFunctor); paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
...@@ -35,13 +35,13 @@ struct LessEqualFunctor { ...@@ -35,13 +35,13 @@ struct LessEqualFunctor {
}; };
template <typename T> template <typename T>
struct LargerThanFunctor { struct GreaterThanFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
}; };
template <typename T> template <typename T>
struct LargerEqualFunctor { struct GreaterEqualFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
}; };
......
...@@ -157,7 +157,9 @@ def monkey_patch_variable(): ...@@ -157,7 +157,9 @@ def monkey_patch_variable():
("__eq__", "equal", False), ("__eq__", "equal", False),
("__ne__", "not_equal", False), ("__ne__", "not_equal", False),
("__lt__", "less_than", False), ("__lt__", "less_than", False),
("__le__", "less_equal", False)): ("__le__", "less_equal", False),
("__gt__", "greater_than", False),
("__ge__", "greater_equal", False)):
setattr(Variable, method_name, setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse)) _elemwise_method_creator_(method_name, op_type, reverse))
......
...@@ -38,7 +38,10 @@ def create_test_class(op_type, typename, callback): ...@@ -38,7 +38,10 @@ def create_test_class(op_type, typename, callback):
for _type_name in {'float32', 'float64', 'int32', 'int64'}: for _type_name in {'float32', 'float64', 'int32', 'int64'}:
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册