未验证 提交 8fb2dce9 编写于 作者: L LoneRanger 提交者: GitHub

Fix Python IndexError of Case12: paddle.static.nn.bilinear_tensor_product (#50008)

上级 b8e6ca92
...@@ -34,6 +34,17 @@ class TestDygraphBilinearTensorProductAPIError(unittest.TestCase): ...@@ -34,6 +34,17 @@ class TestDygraphBilinearTensorProductAPIError(unittest.TestCase):
x1 = fluid.data(name='x1', shape=[-1, 5], dtype="float16") x1 = fluid.data(name='x1', shape=[-1, 5], dtype="float16")
x2 = fluid.data(name='x2', shape=[-1, 4], dtype="float32") x2 = fluid.data(name='x2', shape=[-1, 4], dtype="float32")
self.assertRaises(TypeError, layer, x1, x2) self.assertRaises(TypeError, layer, x1, x2)
# the dimensions of x and y must be 2
paddle.enable_static()
x3 = paddle.static.data("", shape=[0], dtype="float32")
x4 = paddle.static.data("", shape=[0], dtype="float32")
self.assertRaises(
ValueError,
paddle.static.nn.bilinear_tensor_product,
x3,
x4,
1000,
)
class TestBilinearTensorProductOp(OpTest): class TestBilinearTensorProductOp(OpTest):
......
...@@ -2574,7 +2574,12 @@ def bilinear_tensor_product( ...@@ -2574,7 +2574,12 @@ def bilinear_tensor_product(
""" """
helper = LayerHelper('bilinear_tensor_product', **locals()) helper = LayerHelper('bilinear_tensor_product', **locals())
dtype = helper.input_dtype('x') dtype = helper.input_dtype('x')
if len(x.shape) != 2 or len(y.shape) != 2:
raise ValueError(
"Input x and y should be 2D tensor, but received x with the shape of {}, y with the shape of {}".format(
x.shape, y.shape
)
)
param_shape = [size, x.shape[1], y.shape[1]] param_shape = [size, x.shape[1], y.shape[1]]
w = helper.create_parameter( w = helper.create_parameter(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册