提交 4670b58c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2908 dataset: avoid same column name and add para check for to_device

Merge pull request !2908 from ms_yan/column_empty
......@@ -95,7 +95,7 @@ def check_uint32(value, arg_name=""):
def check_pos_int32(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [POS_INT_MIN, INT32_MAX])
check_value(value, [POS_INT_MIN, INT32_MAX], arg_name)
def check_uint64(value, arg_name=""):
......@@ -143,6 +143,8 @@ def check_columns(columns, name):
col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
type_check_list(columns, (str,), col_names)
if len(set(columns)) != len(columns):
raise ValueError("Every column name should not be same with others in column_names.")
def parse_user_args(method, *args, **kwargs):
......
......@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_split, check_bucket_batch_by_length, check_cluedataset
check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
......@@ -939,6 +939,7 @@ class Dataset:
raise TypeError("apply_func must return a dataset.")
return dataset
@check_positive_int32
def device_que(self, prefetch_size=None):
"""
Return a transferredDataset that transfer data through device.
......@@ -956,6 +957,7 @@ class Dataset:
"""
return self.to_device()
@check_positive_int32
def to_device(self, num_batch=None):
"""
Transfer data through CPU, GPU or Ascend devices.
......@@ -973,7 +975,7 @@ class Dataset:
Raises:
TypeError: If device_type is empty.
ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
ValueError: If num_batch is None or 0 or larger than int_max.
ValueError: If num_batch is negative or larger than int_max.
RuntimeError: If dataset is unknown.
RuntimeError: If distribution file path is given but failed to read.
"""
......
......@@ -25,7 +25,7 @@ from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
check_columns, check_positive
check_columns, check_positive, check_pos_int32
from . import datasets
from . import samplers
......@@ -593,6 +593,25 @@ def check_take(method):
return new_method
def check_positive_int32(method):
"""check whether the input argument is positive and int, only works for functions with one input."""
@wraps(method)
def new_method(self, *args, **kwargs):
[count], param_dict = parse_user_args(method, *args, **kwargs)
para_name = None
for key in list(param_dict.keys()):
if key not in ['self', 'cls']:
para_name = key
# Need to get default value of param
if count is not None:
check_pos_int32(count, para_name)
return method(self, *args, **kwargs)
return new_method
def check_zip(method):
"""check the input arguments of zip."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册