未验证 提交 d40c5240 编写于 作者: Z zhongpu 提交者: GitHub

error message enhancement for Linear, test=develop (#23595)

上级 c1c8c7e4
......@@ -939,6 +939,10 @@ class Linear(layers.Layer):
return dygraph_utils._append_activation_in_dygraph(pre_act,
self._act)
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], "Linear")
attrs = {
"x_num_col_dims": len(input.shape) - 1,
"y_num_col_dims": 1,
......
......@@ -155,6 +155,31 @@ class TestLayer(LayerTest):
self.assertTrue(np.array_equal(static_ret, dy_ret_value))
with self.static_graph():
# the input of Linear must be Variable.
def test_Variable():
inp = np.ones([3, 32, 32], dtype='float32')
linear = nn.Linear(
32,
4,
bias_attr=fluid.initializer.ConstantInitializer(value=1))
linear_ret1 = linear(inp)
self.assertRaises(TypeError, test_Variable)
# the input dtype of Linear must be float16 or float32 or float64
# float16 only can be set on GPU place
def test_type():
inp = np.ones([3, 32, 32], dtype='int32')
linear = nn.Linear(
32,
4,
bias_attr=fluid.initializer.ConstantInitializer(value=1))
linear_ret2 = linear(inp)
self.assertRaises(TypeError, test_type)
def test_layer_norm(self):
inp = np.ones([3, 32, 32], dtype='float32')
with self.static_graph():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册