diff --git a/python/paddle/fluid/tests/unittests/test_bmm_op.py b/python/paddle/fluid/tests/unittests/test_bmm_op.py index cb1b3ded53472c022ef83539f573c9e6c192a966..a1c82668420872c67da377579af5e4761e3fe58e 100644 --- a/python/paddle/fluid/tests/unittests/test_bmm_op.py +++ b/python/paddle/fluid/tests/unittests/test_bmm_op.py @@ -79,8 +79,10 @@ class TestBmmAPIError(unittest.TestCase): y_data = np.arange(16, dtype='float32').reshape((2, 4, 2)) y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4)) y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2)) + y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 4, 2)) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2) + self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3) if __name__ == "__main__": diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index c41c9226d16b41934f738719ae1251127d439ccf..2dcdf1603a737c3c488879835ea4d41ed5271247 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -848,6 +848,10 @@ def bmm(x, y, name=None): raise ValueError( "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}". format(x_shape, y_shape)) + if x_shape[0] != y_shape[0]: + raise ValueError( + "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}". + format(x_shape, y_shape)) helper = LayerHelper('bmm', **locals()) if in_dygraph_mode(): return core.ops.bmm(x, y)