提交 28d07e3c 编写于 作者: Q qiaolongfei

add python part of compare op

上级 d4e3495c
......@@ -100,11 +100,12 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_OP(larger_than, "Out = X > Y");
REGISTER_COMPARE_KERNEL(larger_than, CPU, paddle::operators::LargerThanFunctor);
REGISTER_COMPARE_OP(larger_equal, "Out = X >= Y");
REGISTER_COMPARE_KERNEL(larger_equal, CPU,
paddle::operators::LargerEqualFunctor);
REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
REGISTER_COMPARE_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
REGISTER_COMPARE_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_OP(equal, "Out = X == Y");
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
REGISTER_COMPARE_OP(not_equal, "Out = X != Y");
......
......@@ -16,9 +16,9 @@ limitations under the License. */
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_KERNEL(larger_than, CUDA,
paddle::operators::LargerThanFunctor);
REGISTER_COMPARE_KERNEL(larger_equal, CUDA,
paddle::operators::LargerEqualFunctor);
REGISTER_COMPARE_KERNEL(greater_than, CUDA,
paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
......@@ -35,13 +35,13 @@ struct LessEqualFunctor {
};
template <typename T>
struct LargerThanFunctor {
struct GreaterThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
};
template <typename T>
struct LargerEqualFunctor {
struct GreaterEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
};
......
......@@ -157,7 +157,9 @@ def monkey_patch_variable():
("__eq__", "equal", False),
("__ne__", "not_equal", False),
("__lt__", "less_than", False),
("__le__", "less_equal", False)):
("__le__", "less_equal", False),
("__gt__", "greater_than", False),
("__ge__", "greater_equal", False)):
setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse))
......
......@@ -38,7 +38,10 @@ def create_test_class(op_type, typename, callback):
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册