提交 180c1750 编写于 作者: M ms_yan 提交者: 高东海

add parameter check for Class Schema

上级 6dc6d6bc
......@@ -38,7 +38,7 @@ from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_zip, check_rename, \
check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
check_zip_dataset
check_zip_dataset, check_add_column
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
......@@ -2334,13 +2334,20 @@ class Schema:
self.dataset_type = ''
self.num_rows = 0
else:
if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
raise ValueError("The file %s does not exist or permission denied!" % schema_file)
try:
with open(schema_file, 'r') as load_f:
json_obj = json.load(load_f)
self.from_json(json_obj)
except json.decoder.JSONDecodeError:
raise RuntimeError("Schema file failed to load")
raise RuntimeError("Schema file failed to load.")
except UnicodeDecodeError:
raise RuntimeError("Schema file failed to decode.")
except Exception:
raise RuntimeError("Schema file failed to open.")
self.from_json(json_obj)
@check_add_column
def add_column(self, name, de_type, shape=None):
"""
Add new column to the schema.
......@@ -2359,10 +2366,8 @@ class Schema:
if isinstance(de_type, typing.Type):
de_type = mstype_to_detype(de_type)
new_column["type"] = str(de_type)
elif isinstance(de_type, str):
new_column["type"] = str(DataType(de_type))
else:
raise ValueError("Unknown column type")
new_column["type"] = str(DataType(de_type))
if shape is not None:
new_column["shape"] = shape
......@@ -2391,7 +2396,7 @@ class Schema:
Parse the columns and add it to self.
Args:
columns (list[str]): names of columns.
columns (dict or list[str]): names of columns.
Raises:
RuntimeError: If failed to parse schema file.
......@@ -2399,6 +2404,8 @@ class Schema:
RuntimeError: If column's name field is missing.
RuntimeError: If column's type field is missing.
"""
if columns is None:
raise TypeError("Expected non-empty dict or string list.")
self.columns = []
for col in columns:
name = None
......@@ -2443,6 +2450,8 @@ class Schema:
RuntimeError: if dataset type is missing in the object.
RuntimeError: if columns are missing in the object.
"""
if not isinstance(json_obj, dict) or json_obj is None:
raise ValueError("Expected non-empty dict.")
for k, v in json_obj.items():
if k == "datasetType":
self.dataset_type = v
......
......@@ -19,10 +19,15 @@ import inspect as ins
import os
from functools import wraps
from multiprocessing import cpu_count
from mindspore._c_expression import typing
from . import samplers
from . import datasets
INT32_MAX = 2147483647
valid_detype = [
"bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64", "float16", "float32", "float64"
]
def check(method):
......@@ -188,6 +193,12 @@ def check(method):
return wrapper
def check_valid_detype(type_):
if type_ not in valid_detype:
raise ValueError("Unknown column type")
return True
def check_filename(path):
"""
check the filename in the path
......@@ -743,3 +754,42 @@ def check_project(method):
return method(*args, **kwargs)
return new_method
def check_shape(shape, name):
if isinstance(shape, list):
for element in shape:
if not isinstance(element, int):
raise TypeError(
"Each element in {0} should be of type int. Got {1}.".format(name, type(element)))
else:
raise TypeError("Expected int list.")
def check_add_column(method):
"""check the input arguments of add_column."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
# check name; required argument
name = param_dict.get("name")
if not isinstance(name, str) or not name:
raise TypeError("Expected non-empty string.")
# check type; required argument
de_type = param_dict.get("de_type")
if de_type is not None:
if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
raise ValueError("Unknown column type.")
else:
raise TypeError("Expected non-empty string.")
# check shape
shape = param_dict.get("shape")
if shape is not None:
check_shape(shape, "shape")
return method(*args, **kwargs)
return new_method
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册