未验证 提交 dc1901f4 编写于 作者: S songyouwei 提交者: GitHub

API(BilinearTensorProduct) error message enhancement (#23528)

* err msg enhance for BilinearTensorProduct
test=develop

* rebase dev
test=develop

* add ut
test=develop
上级 5b3dd806
...@@ -2371,6 +2371,10 @@ class BilinearTensorProduct(layers.Layer): ...@@ -2371,6 +2371,10 @@ class BilinearTensorProduct(layers.Layer):
is_bias=True) is_bias=True)
def forward(self, x, y): def forward(self, x, y):
check_variable_and_dtype(x, 'x', ['float32', 'float64'],
'BilinearTensorProduct')
check_variable_and_dtype(y, 'y', ['float32', 'float64'],
'BilinearTensorProduct')
self._inputs = {"X": x, "Y": y, "Weight": self.weight} self._inputs = {"X": x, "Y": y, "Weight": self.weight}
if self.bias is not None: if self.bias is not None:
self._inputs["Bias"] = self.bias self._inputs["Bias"] = self.bias
......
...@@ -16,9 +16,25 @@ from __future__ import print_function ...@@ -16,9 +16,25 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid
from op_test import OpTest from op_test import OpTest
class TestDygraphBilinearTensorProductAPIError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
layer = fluid.dygraph.nn.BilinearTensorProduct(
input1_dim=5, input2_dim=4, output_dim=1000)
# the input must be Variable.
x0 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
self.assertRaises(TypeError, layer, x0)
# the input dtype must be float32 or float64
x1 = fluid.data(name='x1', shape=[-1, 5], dtype="float16")
x2 = fluid.data(name='x2', shape=[-1, 4], dtype="float32")
self.assertRaises(TypeError, layer, x1, x2)
class TestBilinearTensorProductOp(OpTest): class TestBilinearTensorProductOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "bilinear_tensor_product" self.op_type = "bilinear_tensor_product"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册