未验证 提交 c72ed824 编写于 作者: W wawltor 提交者: GitHub

fix the paddle compare op for the broadcast when the element equal (#32941)

* fix the paddle compare op for the broadcast

* fix compare op in for in the cuda device
上级 c809530e
...@@ -131,18 +131,18 @@ class CompareOp : public framework::OperatorWithKernel { ...@@ -131,18 +131,18 @@ class CompareOp : public framework::OperatorWithKernel {
REGISTER_COMPARE_OP(less_than, "Out = X < Y"); REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor, REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor,
paddle::operators::GreaterEqualFunctor); paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor, REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor,
paddle::operators::GreaterThanFunctor); paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_OP(greater_than, "Out = X > Y"); REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
REGISTER_COMPARE_KERNEL(greater_than, CPU, REGISTER_COMPARE_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor, paddle::operators::GreaterThanFunctor,
paddle::operators::LessEqualFunctor); paddle::operators::LessThanFunctor);
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y"); REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
REGISTER_COMPARE_KERNEL(greater_equal, CPU, REGISTER_COMPARE_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor, paddle::operators::GreaterEqualFunctor,
paddle::operators::LessThanFunctor); paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_OP(equal, "Out = X == Y"); REGISTER_COMPARE_OP(equal, "Out = X == Y");
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor, REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor,
paddle::operators::EqualFunctor); paddle::operators::EqualFunctor);
......
...@@ -15,15 +15,15 @@ limitations under the License. */ ...@@ -15,15 +15,15 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h" #include "paddle/fluid/operators/controlflow/compare_op.h"
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor, REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
paddle::operators::GreaterThanFunctor); paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_KERNEL(greater_than, CUDA, REGISTER_COMPARE_KERNEL(greater_than, CUDA,
paddle::operators::GreaterThanFunctor, paddle::operators::GreaterThanFunctor,
paddle::operators::LessEqualFunctor); paddle::operators::LessThanFunctor);
REGISTER_COMPARE_KERNEL(greater_equal, CUDA, REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
paddle::operators::GreaterEqualFunctor, paddle::operators::GreaterEqualFunctor,
paddle::operators::LessThanFunctor); paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor, REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor,
paddle::operators::EqualFunctor); paddle::operators::EqualFunctor);
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor, REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor,
......
...@@ -139,6 +139,22 @@ def create_paddle_case(op_type, callback): ...@@ -139,6 +139,22 @@ def create_paddle_case(op_type, callback):
fetch_list=[out]) fetch_list=[out])
self.assertEqual((res == real_result).all(), True) self.assertEqual((res == real_result).all(), True)
def test_broadcast_api_3(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[5], dtype='int32')
y = paddle.static.data(name='y', shape=[3, 1], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
input_x = np.arange(0, 5).reshape((5)).astype(np.int32)
input_y = np.array([5, 3, 2]).reshape((3, 1)).astype(np.int32)
real_result = callback(input_x, input_y)
res, = exe.run(feed={"x": input_x,
"y": input_y},
fetch_list=[out])
self.assertEqual((res == real_result).all(), True)
def test_attr_name(self): def test_attr_name(self):
paddle.enable_static() paddle.enable_static()
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册