From e496640bf40bd68ad9feb991320d84e05c9e677a Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Wed, 30 Sep 2020 14:31:23 +0800 Subject: [PATCH] fix bmm enforce equal batch (#27694) --- python/paddle/fluid/tests/unittests/test_bmm_op.py | 2 ++ python/paddle/tensor/linalg.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_bmm_op.py b/python/paddle/fluid/tests/unittests/test_bmm_op.py index cb1b3ded534..a1c82668420 100644 --- a/python/paddle/fluid/tests/unittests/test_bmm_op.py +++ b/python/paddle/fluid/tests/unittests/test_bmm_op.py @@ -79,8 +79,10 @@ class TestBmmAPIError(unittest.TestCase): y_data = np.arange(16, dtype='float32').reshape((2, 4, 2)) y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4)) y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2)) + y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 4, 2)) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2) + self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3) if __name__ == "__main__": diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index c41c9226d16..2dcdf1603a7 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -848,6 +848,10 @@ def bmm(x, y, name=None): raise ValueError( "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}". format(x_shape, y_shape)) + if x_shape[0] != y_shape[0]: + raise ValueError( + "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}". + format(x_shape, y_shape)) helper = LayerHelper('bmm', **locals()) if in_dygraph_mode(): return core.ops.bmm(x, y) -- GitLab