diff --git a/paddle/fluid/operators/controlflow/compare_op.cc b/paddle/fluid/operators/controlflow/compare_op.cc index 3cad86d96c26a0e25fcbaeb02405315895744e50..bf047de86fc21a4d5d9e9ff8f20c9a1982eb25af 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cc +++ b/paddle/fluid/operators/controlflow/compare_op.cc @@ -23,29 +23,6 @@ limitations under the License. */ namespace paddle { namespace operators { -template -class CompareOpKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - using T = typename Functor::ELEM_TYPE; - using Tensor = framework::Tensor; - - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* z = context.Output("Out"); - int axis = context.Attr("axis"); - - if (x->numel() == 1 && y->numel() == 1) { - bool* z_data = z->mutable_data(context.GetPlace()); - z_data[0] = Functor()(x->data()[0], y->data()[0]); - } else { - ElementwiseComputeEx( - context, x, y, axis, Functor(), z); - } - } -}; - template class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: @@ -153,16 +130,22 @@ class CompareOp : public framework::OperatorWithKernel { REGISTER_COMPARE_OP_VERSION(op_type); REGISTER_COMPARE_OP(less_than, "Out = X < Y"); -REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); +REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor, + paddle::operators::GreaterEqualFunctor); REGISTER_COMPARE_OP(less_equal, "Out = X <= Y"); -REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); +REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor, + paddle::operators::GreaterThanFunctor); REGISTER_COMPARE_OP(greater_than, "Out = X > Y"); REGISTER_COMPARE_KERNEL(greater_than, CPU, - paddle::operators::GreaterThanFunctor); + paddle::operators::GreaterThanFunctor, + paddle::operators::LessEqualFunctor); REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y"); REGISTER_COMPARE_KERNEL(greater_equal, CPU, - paddle::operators::GreaterEqualFunctor); + paddle::operators::GreaterEqualFunctor, + paddle::operators::LessThanFunctor); REGISTER_COMPARE_OP(equal, "Out = X == Y"); -REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor); +REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor, + paddle::operators::EqualFunctor); REGISTER_COMPARE_OP(not_equal, "Out = X != Y"); -REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor); +REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor, + paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/controlflow/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu index b1f306358359764b919f9e570cf44f9733a7d178..3ca700e16e6e7bcf4136ca68dd895593a63824ec 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cu +++ b/paddle/fluid/operators/controlflow/compare_op.cu @@ -14,11 +14,17 @@ limitations under the License. */ #include "paddle/fluid/operators/controlflow/compare_op.h" -REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); -REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); -REGISTER_COMPARE_KERNEL(greater_than, CUDA, +REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor, + paddle::operators::GreaterEqualFunctor); +REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor, paddle::operators::GreaterThanFunctor); +REGISTER_COMPARE_KERNEL(greater_than, CUDA, + paddle::operators::GreaterThanFunctor, + paddle::operators::LessEqualFunctor); REGISTER_COMPARE_KERNEL(greater_equal, CUDA, - paddle::operators::GreaterEqualFunctor); -REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); -REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); + paddle::operators::GreaterEqualFunctor, + paddle::operators::LessThanFunctor); +REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor, + paddle::operators::EqualFunctor); +REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor, + paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/controlflow/compare_op.h b/paddle/fluid/operators/controlflow/compare_op.h index b7529e4ae632d31524846d9d5aa4b1883f4509a1..ff929ee7dfce79536a9ce7c8ae6878fb7e3871e9 100644 --- a/paddle/fluid/operators/controlflow/compare_op.h +++ b/paddle/fluid/operators/controlflow/compare_op.h @@ -68,7 +68,7 @@ struct NotEqualFunctor { } }; -template +template class CompareOpKernel : public framework::OpKernel { public: @@ -80,21 +80,33 @@ class CompareOpKernel auto* y = context.Input("Y"); auto* z = context.Output("Out"); int axis = context.Attr("axis"); - ElementwiseComputeEx(context, x, y, axis, - Functor(), z); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + if (x_dims.size() >= y_dims.size()) { + ElementwiseComputeEx(context, x, y, axis, + Functor(), z); + } else { + ElementwiseComputeEx( + context, x, y, axis, InverseFunctor(), z); + } } }; } // namespace operators } // namespace paddle -#define REGISTER_COMPARE_KERNEL(op_type, dev, functor) \ - REGISTER_OP_##dev##_KERNEL( \ - op_type, ::paddle::operators::CompareOpKernel< \ - ::paddle::platform::dev##DeviceContext, functor>, \ - ::paddle::operators::CompareOpKernel< \ - ::paddle::platform::dev##DeviceContext, functor>, \ - ::paddle::operators::CompareOpKernel< \ - ::paddle::platform::dev##DeviceContext, functor>, \ - ::paddle::operators::CompareOpKernel< \ - ::paddle::platform::dev##DeviceContext, functor>); +#define REGISTER_COMPARE_KERNEL(op_type, dev, functor, inverse_functor) \ + REGISTER_OP_##dev##_KERNEL(op_type, \ + ::paddle::operators::CompareOpKernel< \ + ::paddle::platform::dev##DeviceContext, \ + functor, inverse_functor>, \ + ::paddle::operators::CompareOpKernel< \ + ::paddle::platform::dev##DeviceContext, \ + functor, inverse_functor>, \ + ::paddle::operators::CompareOpKernel< \ + ::paddle::platform::dev##DeviceContext, \ + functor, inverse_functor>, \ + ::paddle::operators::CompareOpKernel< \ + ::paddle::platform::dev##DeviceContext, \ + functor, inverse_functor>); diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index fbf7384b86bc1c844dee09c5b439a523026044e5..8dc80c893126925ff0643b4cde622fe504c6b1d9 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -122,6 +122,23 @@ def create_paddle_case(op_type, callback): fetch_list=[out]) self.assertEqual((res == real_result).all(), True) + def test_broadcast_api_2(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name='x', shape=[1, 2, 3], dtype='int32') + y = paddle.static.data( + name='y', shape=[1, 2, 1, 3], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.arange(0, 6).reshape((1, 2, 3)).astype(np.int32) + input_y = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(np.int32) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + def test_attr_name(self): paddle.enable_static() with program_guard(Program(), Program()):