提交 d4e3495c 编写于 作者: Q qiaolongfei

add larger_than and larger_equal op and kernel

上级 bad01596
...@@ -100,6 +100,11 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y"); ...@@ -100,6 +100,11 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_OP(larger_than, "Out = X > Y");
REGISTER_COMPARE_KERNEL(larger_than, CPU, paddle::operators::LargerThanFunctor);
REGISTER_COMPARE_OP(larger_equal, "Out = X >= Y");
REGISTER_COMPARE_KERNEL(larger_equal, CPU,
paddle::operators::LargerEqualFunctor);
REGISTER_COMPARE_OP(equal, "Out = X == Y"); REGISTER_COMPARE_OP(equal, "Out = X == Y");
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor); REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
REGISTER_COMPARE_OP(not_equal, "Out = X != Y"); REGISTER_COMPARE_OP(not_equal, "Out = X != Y");
......
...@@ -16,5 +16,9 @@ limitations under the License. */ ...@@ -16,5 +16,9 @@ limitations under the License. */
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_KERNEL(larger_than, CUDA,
paddle::operators::LargerThanFunctor);
REGISTER_COMPARE_KERNEL(larger_equal, CUDA,
paddle::operators::LargerEqualFunctor);
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
...@@ -34,6 +34,18 @@ struct LessEqualFunctor { ...@@ -34,6 +34,18 @@ struct LessEqualFunctor {
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
}; };
template <typename T>
struct LargerThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
};
template <typename T>
struct LargerEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
};
template <typename T> template <typename T>
struct EqualFunctor { struct EqualFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册