From e84b2e9b0564b79fd2c4c0f72379b27317343b95 Mon Sep 17 00:00:00 2001
From: Zhanlue Yang
Date: Mon, 16 Aug 2021 11:08:13 +0800
Subject: [PATCH] Add bcast semantics checks at C++ level to
 BroadcastTensorsOp (#34874)

---
 .../fluid/operators/broadcast_tensors_op.cc   |  9 ++++
 .../unittests/test_broadcast_tensors_op.py    | 42 +++++++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/paddle/fluid/operators/broadcast_tensors_op.cc b/paddle/fluid/operators/broadcast_tensors_op.cc
index 074607e05ea..bd85c0029da 100644
--- a/paddle/fluid/operators/broadcast_tensors_op.cc
+++ b/paddle/fluid/operators/broadcast_tensors_op.cc
@@ -38,6 +38,7 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {
     int target_rank = 0;
     const auto& input_dims = ctx->GetInputsDim("X");
+
     // 1. Find Output rank = max(Inputs rank)
     for (const auto& input_ddim : input_dims) {
       target_rank = std::max(target_rank, input_ddim.size());
     }
@@ -64,6 +65,14 @@
           dim_size = input_ddim[axis];
         }

+        if (target_dim_size != 1 && dim_size != 1 &&
+            target_dim_size != dim_size) {
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "BroadcastTensorsOp inputs do not satisfy bcast semantics, "
+              "please check axis = %d in reverse order",
+              index));
+        }
+
         // We performed bcast semantics check at python level
         // So input tensors should all have legal shape
         target_dim_size = std::max(target_dim_size, dim_size);
diff --git a/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py b/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py
index 602c5bae8f8..f60e4067a09 100644
--- a/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py
+++ b/python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py
@@ -192,5 +192,47 @@ class TestRaiseBroadcastTensorsError(unittest.TestCase):
         self.assertRaises(TypeError, test_bcast_semantics)


+class TestRaiseBroadcastTensorsErrorDyGraph(unittest.TestCase):
+    def test_errors(self):
+        def test_type():
+            inputs = [
+                paddle.to_tensor(
+                    np.ones(
+                        shape=[1, 1, 1, 1], dtype='float32')),
+                paddle.to_tensor(
+                    np.ones(
+                        shape=[1, 4, 1, 1], dtype='float64'))
+            ]
+            paddle.broadcast_tensors(inputs)
+
+        def test_dtype():
+            inputs = [
+                paddle.to_tensor(
+                    np.ones(
+                        shape=[1, 1, 1, 1], dtype='int8')),
+                paddle.to_tensor(
+                    np.ones(
+                        shape=[1, 4, 1, 1], dtype='int8'))
+            ]
+            paddle.broadcast_tensors(inputs)
+
+        def test_bcast_semantics():
+            inputs = [
+                paddle.to_tensor(
+                    np.ones(
+                        shape=[1, 3, 1, 1], dtype='float32')),
+                paddle.to_tensor(
+                    np.ones(
+                        shape=[1, 8, 1, 1], dtype='float32'))
+            ]
+            paddle.broadcast_tensors(inputs)
+
+        paddle.disable_static()
+        self.assertRaises(TypeError, test_type)
+        self.assertRaises(TypeError, test_dtype)
+        self.assertRaises(TypeError, test_bcast_semantics)
+        paddle.enable_static()
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab
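
For reference, a minimal Python sketch of the broadcast rule the new C++
check enforces: shapes are right-aligned, and along each axis the sizes must
all agree or be 1. The helper name bcast_output_shape is illustrative only,
not part of Paddle's API.

def bcast_output_shape(shapes):
    """Right-align shapes; on each axis, sizes must match or be 1."""
    rank = max(len(s) for s in shapes)  # output rank = max of input ranks
    out = []
    for index in range(rank):  # walk axes in reverse order, as the op does
        target = 1
        for s in shapes:
            axis = len(s) - index - 1
            size = s[axis] if axis >= 0 else 1  # missing leading dims act as 1
            if target != 1 and size != 1 and target != size:
                raise ValueError(
                    "inputs do not satisfy bcast semantics, "
                    "please check axis = %d in reverse order" % index)
            target = max(target, size)
        out.append(target)
    return out[::-1]

# [1, 4, 1, 1] with [1, 1, 1, 1] broadcasts; [1, 3, 1, 1] with [1, 8, 1, 1]
# (the test_bcast_semantics case) fails at axis = 2 in reverse order.
assert bcast_output_shape([[1, 4, 1, 1], [1, 1, 1, 1]]) == [1, 4, 1, 1]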