diff --git a/paddle/fluid/operators/controlflow/compare_op.cc b/paddle/fluid/operators/controlflow/compare_op.cc index fcc3e5a2c13722b45f86f0bd3cee595e71f33421..60f29ba39a8ee64f9fe5d95e685cac1fb52dfd21 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cc +++ b/paddle/fluid/operators/controlflow/compare_op.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/controlflow/compare_op.h" +#include #include +#include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { @@ -85,14 +88,22 @@ class CompareOp : public framework::OperatorWithKernel { auto dim_x = context->GetInputDim("X"); auto dim_y = context->GetInputDim("Y"); - PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(), - platform::errors::InvalidArgument( - "The size of dim_y should not be greater than " - "dim_x's, but received dim_y: %d > dim_x: %d.\n", - dim_y.size(), dim_x.size())); - - context->SetOutputDim("Out", context->GetInputDim("X")); - context->ShareLoD("X", "Out"); + if (context->GetInputDim("X") == context->GetInputDim("Y")) { + context->ShareDim("X", /*->*/ "Out"); + context->ShareLoD("X", /*->*/ "Out"); + } else { + int max_dim = std::max(dim_x.size(), dim_y.size()); + int axis = std::abs(dim_x.size() - dim_y.size()); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + GetBroadcastDimsArrays(dim_x, dim_y, x_dims_array.data(), + y_dims_array.data(), out_dims_array.data(), + max_dim, axis); + context->SetOutputDim("Out", framework::make_ddim(out_dims_array)); + // to do + context->ShareLoD("X", /*->*/ "Out"); + } } framework::OpKernelType GetExpectedKernelType( diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index ef687ff75c6fd22439aba81a9763b4f177a0f614..a97f54d6cac1ea91f05cb3dc68729f5b68df7c9e 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -72,25 +72,40 @@ def create_paddle_case(op_type, callback): class PaddleCls(unittest.TestCase): def setUp(self): self.op_type = op_type - self.input_x = np.array([1, 2, 3, 4]) - self.input_y = np.array([1, 3, 2, 4]) + self.input_x = np.array([1, 2, 3, 4]).astype(np.int64) + self.input_y = np.array([1, 3, 2, 4]).astype(np.int64) self.real_result = callback(self.input_x, self.input_y) + self.place = fluid.CPUPlace() + if core.is_compiled_with_cuda(): + self.place = paddle.CUDAPlace(0) def test_api(self): with program_guard(Program(), Program()): - x = fluid.layers.data(name='x', shape=[4], dtype='int64') - y = fluid.layers.data(name='y', shape=[4], dtype='int64') + x = fluid.data(name='x', shape=[4], dtype='int64') + y = fluid.data(name='y', shape=[4], dtype='int64') op = eval("paddle.%s" % (self.op_type)) out = op(x, y) - place = fluid.CPUPlace() - if core.is_compiled_with_cuda(): - place = paddle.CUDAPlace(0) - exe = fluid.Executor(place) + exe = fluid.Executor(self.place) res, = exe.run(feed={"x": self.input_x, "y": self.input_y}, fetch_list=[out]) self.assertEqual((res == self.real_result).all(), True) + def test_broadcast_api_1(self): + with program_guard(Program(), Program()): + x = paddle.nn.data(name='x', shape=[1, 2, 1, 3], dtype='int32') + y = paddle.nn.data(name='y', shape=[1, 2, 3], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.Executor(self.place) + input_x = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(np.int32) + input_y = np.arange(0, 6).reshape((1, 2, 3)).astype(np.int32) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + def test_attr_name(self): with program_guard(Program(), Program()): x = fluid.layers.data(name='x', shape=[4], dtype='int32') @@ -104,6 +119,7 @@ def create_paddle_case(op_type, callback): globals()[cls_name] = PaddleCls +create_paddle_case('less_than', lambda _a, _b: _a < _b) create_paddle_case('less_equal', lambda _a, _b: _a <= _b) create_paddle_case('greater_than', lambda _a, _b: _a > _b) create_paddle_case('greater_equal', lambda _a, _b: _a >= _b)