未验证 提交 b2f530e3 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #8532 from jacquesqiao/fix-compare-op

Fix compare op
......@@ -83,7 +83,7 @@ class CompareOp : public framework::OperatorWithKernel {
} // namespace operators
} // namespace paddle
#define REGISTER_LOGICAL_OP(op_type, _equation) \
#define REGISTER_COMPARE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
......@@ -96,11 +96,17 @@ class CompareOp : public framework::OperatorWithKernel {
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker);
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_OP(not_equal, "Out = X != Y");
REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);
REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
REGISTER_COMPARE_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
REGISTER_COMPARE_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_OP(equal, "Out = X == Y");
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
REGISTER_COMPARE_OP(not_equal, "Out = X != Y");
REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);
......@@ -14,7 +14,11 @@ limitations under the License. */
#include "paddle/fluid/operators/compare_op.h"
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_KERNEL(greater_than, CUDA,
paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
......@@ -34,6 +34,18 @@ struct LessEqualFunctor {
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
};
template <typename T>
struct GreaterThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
};
template <typename T>
struct GreaterEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
};
template <typename T>
struct EqualFunctor {
using ELEM_TYPE = T;
......@@ -76,7 +88,7 @@ class CompareOpKernel
} // namespace operators
} // namespace paddle
#define REGISTER_LOGICAL_KERNEL(op_type, dev, functor) \
#define REGISTER_COMPARE_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, ::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int>>, \
......
......@@ -157,7 +157,9 @@ def monkey_patch_variable():
("__eq__", "equal", False),
("__ne__", "not_equal", False),
("__lt__", "less_than", False),
("__le__", "less_equal", False)):
("__le__", "less_equal", False),
("__gt__", "greater_than", False),
("__ge__", "greater_equal", False)):
setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse))
......
......@@ -38,7 +38,10 @@ def create_test_class(op_type, typename, callback):
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册