提交 cf797d3b 编写于 作者: B buxue

add arg check for enumerate

上级 4670b58c
......@@ -108,7 +108,8 @@ def enumerate_(x, start=0):
"""Enumerate list or tuple."""
x_type = F.typeof(x)
ret = ()
if check_is_tuple_or_list(x_type, "enumerate"):
op_name = "enumerate"
if check_is_tuple_or_list(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"):
ret = zip(range(start, start + len(x)), x)
return ret
......@@ -123,11 +124,22 @@ def while_cond(x):
@constexpr
def check_is_tuple_or_list(x, op_name):
def check_is_tuple_or_list(x, op_name, arg_name):
"""check whether x is list or tuple."""
if isinstance(x, (mstype.list_type, mstype.tuple_type)):
return True
raise TypeError(f"For '{op_name}', the input parameter should be tuple or list, but got {x}.")
raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list, but got {x}.")
@constexpr
def check_is_const_int(x, op_name, arg_name):
"""check whether x is const int."""
if x is None:
raise ValueError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.")
if not isinstance(x, int):
raise ValueError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.")
return True
@constexpr
def check_is_tensor_bool_cond(shp):
......
......@@ -91,6 +91,7 @@ def test_enumerate_tuple_parameter():
index_sum += i
ret += (j,)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
net = Net()
net(x, x, x)
......@@ -127,10 +128,12 @@ def test_enumerate_tuple_parameter_1():
index_sum += i[0]
ret += (i[1],)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
net = Net()
net(x, x, x)
def test_enumerate_tuple_const_2():
class Net(nn.Cell):
def __init__(self):
......@@ -162,20 +165,37 @@ def test_enumerate_tuple_parameter_2():
index_sum += i[0]
ret += (i[1],)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
net = Net()
net(x, x, x)
def test_enumerate_parameter_type_error():
def test_enumerate_first_input_type_error():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
return enumerate(x)
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
net = Net()
with pytest.raises(TypeError) as ex:
net(x)
assert "For 'enumerate', the input parameter should be tuple or list" in str(ex.value)
assert "For 'enumerate', the 'first input'" in str(ex.value)
def test_enumerate_start_type_error():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
return enumerate(x, start=1.2)
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
net = Net()
with pytest.raises(ValueError) as ex:
net((x, x))
assert "For 'enumerate', the 'start'" in str(ex.value)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册