未验证 提交 e496640b 编写于 作者: Y yaoxuefeng 提交者: GitHub

fix bmm enforce equal batch (#27694)

上级 742cbe66
...@@ -79,8 +79,10 @@ class TestBmmAPIError(unittest.TestCase): ...@@ -79,8 +79,10 @@ class TestBmmAPIError(unittest.TestCase):
y_data = np.arange(16, dtype='float32').reshape((2, 4, 2)) y_data = np.arange(16, dtype='float32').reshape((2, 4, 2))
y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4)) y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4))
y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2)) y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2))
y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 4, 2))
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1)
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2) self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2)
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -848,6 +848,10 @@ def bmm(x, y, name=None): ...@@ -848,6 +848,10 @@ def bmm(x, y, name=None):
raise ValueError( raise ValueError(
"x's width must be equal with y's height. But received x's shape: {}, y's shape: {}". "x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".
format(x_shape, y_shape)) format(x_shape, y_shape))
if x_shape[0] != y_shape[0]:
raise ValueError(
"x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".
format(x_shape, y_shape))
helper = LayerHelper('bmm', **locals()) helper = LayerHelper('bmm', **locals())
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.bmm(x, y) return core.ops.bmm(x, y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册