未验证 提交 e84b2e9b 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Add bcast semantics checks at C++ level to BroadcastTensorsOp (#34874)

上级 28279f6f
......@@ -38,6 +38,7 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {
int target_rank = 0;
const auto& input_dims = ctx->GetInputsDim("X");
// 1. Find Output rank = max(Inputs rank)
for (const auto& input_ddim : input_dims) {
target_rank = std::max(target_rank, input_ddim.size());
......@@ -64,6 +65,14 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {
dim_size = input_ddim[axis];
}
if (target_dim_size != 1 && dim_size != 1 &&
target_dim_size != dim_size) {
PADDLE_THROW(platform::errors::InvalidArgument(
"BroadcastTensorsOp inputs does not satisfy bcast semantics,"
"Please check axis = %d in reverse order",
index));
}
// We performed bcast semantics check at python level
// So input tensors should all have legal shape
target_dim_size = std::max(target_dim_size, dim_size);
......
......@@ -192,5 +192,47 @@ class TestRaiseBroadcastTensorsError(unittest.TestCase):
self.assertRaises(TypeError, test_bcast_semantics)
class TestRaiseBroadcastTensorsErrorDyGraph(unittest.TestCase):
def test_errors(self):
def test_type():
inputs = [
paddle.to_tensor(
np.ones(
shape=[1, 1, 1, 1], dtype='float32', name="x4")),
paddle.to_tensor(
np.ones(
shape=[1, 4, 1, 1], dtype='float64', name="x5"))
]
paddle.broadcast_tensors(inputs)
def test_dtype():
inputs = [
paddle.to_tensor(
np.ones(
shape=[1, 1, 1, 1], dtype='int8', name="x6")),
paddle.to_tensor(
np.ones(
shape=[1, 4, 1, 1], dtype='int8', name="x7"))
]
paddle.broadcast_tensors(inputs)
def test_bcast_semantics():
inputs = [
paddle.to_tensor(
np.ones(
shape=[1, 3, 1, 1], dtype='float32', name="x9")),
paddle.to_tensor(
np.ones(
shape=[1, 8, 1, 1], dtype='float32', name="x10"))
]
paddle.broadcast_tensors(inputs)
paddle.disable_static()
self.assertRaises(TypeError, test_type)
self.assertRaises(TypeError, test_dtype)
self.assertRaises(TypeError, test_bcast_semantics)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册