未验证 提交 7d4e2158 编写于 作者: L LutaoChu 提交者: GitHub

add parameters check in static mode for diag op

add parameters check in static mode for diag op 
上级 a7db9acc
......@@ -83,6 +83,14 @@ class TestDiagV2Error(unittest.TestCase):
self.assertRaises(TypeError, test_diag_v2_type)
x = paddle.static.data('data', [3, 3])
self.assertRaises(TypeError, paddle.diag, x, offset=2.5)
self.assertRaises(TypeError, paddle.diag, x, padding_value=[9])
x = paddle.static.data('data2', [3, 3, 3])
self.assertRaises(ValueError, paddle.diag, x)
class TestDiagV2API(unittest.TestCase):
def setUp(self):
......
......@@ -978,6 +978,13 @@ def diag(x, offset=0, padding_value=0, name=None):
check_type(x, 'x', (Variable), 'diag_v2')
check_dtype(x.dtype, 'x', ['float32', 'float64', 'int32', 'int64'],
'diag_v2')
check_type(offset, 'offset', (int), 'diag_v2')
check_type(padding_value, 'padding_value', (int, float), 'diag_v2')
if len(x.shape) != 1 and len(x.shape) != 2:
raise ValueError(
"The dimension of input x must be either 1 or 2, but received {}".
format(len(x.shape)))
helper = LayerHelper("diag_v2", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册