提交 986922eb 编写于 作者: W wangshuide2020

check the legal of dims in the entrance of interface and make it more rigorous.

上级 710ee695
...@@ -42,8 +42,7 @@ def convert_array_from_str(dims, limit=0): ...@@ -42,8 +42,7 @@ def convert_array_from_str(dims, limit=0):
Raises: Raises:
ParamValueError, If flexible dimensions exceed limit value. ParamValueError, If flexible dimensions exceed limit value.
""" """
dims = dims.replace('[', '') \ dims = dims.strip().lstrip('[').rstrip(']')
.replace(']', '')
dims_list = [] dims_list = []
count = 0 count = 0
for dim in dims.split(','): for dim in dims.split(','):
...@@ -131,6 +130,23 @@ class TensorProcessor(BaseProcessor): ...@@ -131,6 +130,23 @@ class TensorProcessor(BaseProcessor):
UrlDecodeError, If unquote train id error with strict mode. UrlDecodeError, If unquote train id error with strict mode.
""" """
Validation.check_param_empty(train_id=train_ids, tag=tags) Validation.check_param_empty(train_id=train_ids, tag=tags)
if dims is not None:
if not isinstance(dims, str):
raise ParamValueError('The type of dims must be str, but got {}.'.format(type(dims)))
dims = dims.strip()
if not (dims.startswith('[') and dims.endswith(']')):
raise ParamValueError('The value: {} of dims must be '
'start with `[` and end with `]`.'.format(dims))
for dim in dims[1:-1].split(','):
dim = dim.strip()
if dim == ":":
continue
if dim.startswith('-'):
dim = dim[1:]
if not dim.isdigit():
raise ParamValueError('The value: {} of dims in the square brackets '
'must be int or `:`.'.format(dims))
for index, train_id in enumerate(train_ids): for index, train_id in enumerate(train_ids):
try: try:
train_id = unquote(train_id, errors='strict') train_id = unquote(train_id, errors='strict')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册