提交 102b205c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2833 Fix engine validators.py

Merge pull request !2833 from nhussain/engine_validators
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
General Validators.
"""
import inspect
from multiprocessing import cpu_count
import os
import numpy as np
from ..engine import samplers
# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
UINT8_MAX = 255
UINT8_MIN = 0
UINT32_MAX = 4294967295
UINT32_MIN = 0
UINT64_MAX = 18446744073709551615
UINT64_MIN = 0
INT32_MAX = 2147483647
INT32_MIN = -2147483648
INT64_MAX = 9223372036854775807
INT64_MIN = -9223372036854775808
FLOAT_MAX_INTEGER = 16777216
FLOAT_MIN_INTEGER = -16777216
DOUBLE_MAX_INTEGER = 9007199254740992
DOUBLE_MIN_INTEGER = -9007199254740992
valid_detype = [
"bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64", "float16", "float32", "float64", "string"
]
def pad_arg_name(arg_name):
if arg_name != "":
arg_name = arg_name + " "
return arg_name
def check_value(value, valid_range, arg_name=""):
arg_name = pad_arg_name(arg_name)
if value < valid_range[0] or value > valid_range[1]:
raise ValueError(
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
valid_range[1]))
def check_range(values, valid_range, arg_name=""):
arg_name = pad_arg_name(arg_name)
if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
raise ValueError(
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
valid_range[1]))
def check_positive(value, arg_name=""):
arg_name = pad_arg_name(arg_name)
if value <= 0:
raise ValueError("Input {0}must be greater than 0.".format(arg_name))
def check_positive_float(value, arg_name=""):
arg_name = pad_arg_name(arg_name)
type_check(value, (float,), arg_name)
check_positive(value, arg_name)
def check_2tuple(value, arg_name=""):
if not (isinstance(value, tuple) and len(value) == 2):
raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name))
def check_uint8(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [UINT8_MIN, UINT8_MAX])
def check_uint32(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [UINT32_MIN, UINT32_MAX])
def check_pos_int32(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [POS_INT_MIN, INT32_MAX])
def check_uint64(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [UINT64_MIN, UINT64_MAX])
def check_pos_int64(value, arg_name=""):
type_check(value, (int,), arg_name)
check_value(value, [UINT64_MIN, INT64_MAX])
def check_pos_float32(value, arg_name=""):
check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name)
def check_pos_float64(value, arg_name=""):
check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER], arg_name)
def check_valid_detype(type_):
if type_ not in valid_detype:
raise ValueError("Unknown column type")
return True
def check_columns(columns, name):
type_check(columns, (list, str), name)
if isinstance(columns, list):
if not columns:
raise ValueError("Column names should not be empty")
col_names = ["col_{0}".format(i) for i in range(len(columns))]
type_check_list(columns, (str,), col_names)
def parse_user_args(method, *args, **kwargs):
"""
Parse user arguments in a function
Args:
method (method): a callable function
*args: user passed args
**kwargs: user passed kwargs
Returns:
user_filled_args (list): values of what the user passed in for the arguments,
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
"""
sig = inspect.signature(method)
if 'self' in sig.parameters or 'cls' in sig.parameters:
ba = sig.bind(method, *args, **kwargs)
ba.apply_defaults()
params = list(sig.parameters.keys())[1:]
else:
ba = sig.bind(*args, **kwargs)
ba.apply_defaults()
params = list(sig.parameters.keys())
user_filled_args = [ba.arguments.get(arg_value) for arg_value in params]
return user_filled_args, ba.arguments
def type_check_list(args, types, arg_names):
"""
Check the type of each parameter in the list
Args:
args (list, tuple): a list or tuple of any variable
types (tuple): tuple of all valid types for arg
arg_names (list, tuple of str): the names of args
Returns:
Exception: when the type is not correct, otherwise nothing
"""
type_check(args, (list, tuple,), arg_names)
if len(args) != len(arg_names):
raise ValueError("List of arguments is not the same length as argument_names.")
for arg, arg_name in zip(args, arg_names):
type_check(arg, types, arg_name)
def type_check(arg, types, arg_name):
"""
Check the type of the parameter
Args:
arg : any variable
types (tuple): tuple of all valid types for arg
arg_name (str): the name of arg
Returns:
Exception: when the type is not correct, otherwise nothing
"""
# handle special case of booleans being a subclass of ints
print_value = '\"\"' if repr(arg) == repr('') else arg
if int in types and bool not in types:
if isinstance(arg, bool):
raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types))
if not isinstance(arg, types):
raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types))
def check_filename(path):
"""
check the filename in the path
Args:
path (str): the path
Returns:
Exception: when error
"""
if not isinstance(path, str):
raise TypeError("path: {} is not string".format(path))
filename = os.path.basename(path)
# '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
# '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>',
# '*', '(', '%', ')', '-', '=', '{', '?', '$'
forbidden_symbols = set(r'\/:*?"<>|`&\';')
if set(filename) & forbidden_symbols:
raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'")
if filename.startswith(' ') or filename.endswith(' '):
raise ValueError("filename should not start/end with space")
return True
def check_dir(dataset_dir):
if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))
def check_file(dataset_file):
check_filename(dataset_file)
if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
raise ValueError("The file {} does not exist or permission denied!".format(dataset_file))
def check_sampler_shuffle_shard_options(param_dict):
"""
Check for valid shuffle, sampler, num_shards, and shard_id inputs.
Args:
param_dict (dict): param_dict
Returns:
Exception: ValueError or RuntimeError if error
"""
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler")
if sampler is not None:
if shuffle is not None:
raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
if num_shards is not None:
check_pos_int32(num_shards)
if shard_id is None:
raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
check_value(shard_id, [0, num_shards - 1], "shard_id")
if num_shards is None and shard_id is not None:
raise RuntimeError("shard_id is specified but num_shards is not.")
def check_padding_options(param_dict):
"""
Check for valid padded_sample and num_padded of padded samples
Args:
param_dict (dict): param_dict
Returns:
Exception: ValueError or RuntimeError if error
"""
columns_list = param_dict.get('columns_list')
block_reader = param_dict.get('block_reader')
padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded')
if padded_sample is not None:
if num_padded is None:
raise RuntimeError("padded_sample is specified and requires num_padded as well.")
if num_padded < 0:
raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded))
if columns_list is None:
raise RuntimeError("padded_sample is specified and requires columns_list as well.")
for column in columns_list:
if column not in padded_sample:
raise ValueError("padded_sample cannot match columns_list.")
if block_reader:
raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.")
if padded_sample is None and num_padded is not None:
raise RuntimeError("num_padded is specified but padded_sample is not.")
def check_num_parallel_workers(value):
type_check(value, (int,), "num_parallel_workers")
if value < 1 or value > cpu_count():
raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count()))
def check_num_samples(value):
type_check(value, (int,), "num_samples")
check_value(value, [0, INT32_MAX], "num_samples")
def validate_dataset_param_value(param_list, param_dict, param_type):
for param_name in param_list:
if param_dict.get(param_name) is not None:
if param_name == 'num_parallel_workers':
check_num_parallel_workers(param_dict.get(param_name))
if param_name == 'num_samples':
check_num_samples(param_dict.get(param_name))
else:
type_check(param_dict.get(param_name), (param_type,), param_name)
def check_gnn_list_or_ndarray(param, param_name):
"""
Check if the input parameter is list or numpy.ndarray.
Args:
param (list, nd.ndarray): param
param_name (str): param_name
Returns:
Exception: TypeError if error
"""
type_check(param, (list, np.ndarray), param_name)
if isinstance(param, list):
param_names = ["param_{0}".format(i) for i in range(len(param))]
type_check_list(param, (int,), param_names)
elif isinstance(param, np.ndarray):
if not param.dtype == np.int32:
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
param_name, param.dtype))
......@@ -98,7 +98,7 @@ class Ngram(cde.NgramOp):
"""
@check_ngram
def __init__(self, n, left_pad=None, right_pad=None, separator=None):
def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "):
super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0],
r_pad_token=right_pad[0], separator=separator)
......
......@@ -28,6 +28,7 @@ __all__ = [
"Vocab", "to_str", "to_bytes"
]
class Vocab(cde.Vocab):
"""
Vocab object that is used to lookup a word.
......@@ -38,7 +39,7 @@ class Vocab(cde.Vocab):
@classmethod
@check_from_dataset
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
special_first=None):
special_first=True):
"""
Build a vocab from a dataset.
......@@ -62,13 +63,21 @@ class Vocab(cde.Vocab):
special_tokens(list, optional): a list of strings, each one is a special token. for example
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens
is specified and special_first is set to None, special_tokens will be prepended (default=None).
is specified and special_first is set to True, special_tokens will be prepended (default=True).
Returns:
Vocab, Vocab object built from dataset.
"""
vocab = Vocab()
if columns is None:
columns = []
if not isinstance(columns, list):
columns = [columns]
if freq_range is None:
freq_range = (None, None)
if special_tokens is None:
special_tokens = []
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
for d in root.create_dict_iterator():
if d is not None:
......@@ -77,7 +86,7 @@ class Vocab(cde.Vocab):
@classmethod
@check_from_list
def from_list(cls, word_list, special_tokens=None, special_first=None):
def from_list(cls, word_list, special_tokens=None, special_first=True):
"""
Build a vocab object from a list of word.
......@@ -86,29 +95,33 @@ class Vocab(cde.Vocab):
special_tokens(list, optional): a list of strings, each one is a special token. for example
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
is specified and special_first is set to None, special_tokens will be prepended (default=None).
is specified and special_first is set to True, special_tokens will be prepended (default=True).
"""
if special_tokens is None:
special_tokens = []
return super().from_list(word_list, special_tokens, special_first)
@classmethod
@check_from_file
def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None):
def from_file(cls, file_path, delimiter="", vocab_size=None, special_tokens=None, special_first=True):
"""
Build a vocab object from a list of word.
Args:
file_path (str): path to the file which contains the vocab list.
delimiter (str, optional): a delimiter to break up each line in file, the first element is taken to be
the word (default=None).
the word (default="").
vocab_size (int, optional): number of words to read from file_path (default=None, all words are taken).
special_tokens (list, optional): a list of strings, each one is a special token. for example
special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
special_first (bool, optional): whether special_tokens will be prepended/appended to vocab,
If special_tokens is specified and special_first is set to None,
special_tokens will be prepended (default=None).
If special_tokens is specified and special_first is set to True,
special_tokens will be prepended (default=True).
"""
if vocab_size is None:
vocab_size = -1
if special_tokens is None:
special_tokens = []
return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first)
@classmethod
......
......@@ -17,23 +17,22 @@ validators for text ops
"""
from functools import wraps
import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype
import mindspore._c_dataengine as cde
from mindspore._c_expression import typing
from ..transforms.validators import check_uint32, check_pos_int64
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, check_positive, \
INT32_MAX, check_value
def check_unique_list_of_words(words, arg_name):
"""Check that words is a list and each element is a str without any duplication"""
if not isinstance(words, list):
raise ValueError(arg_name + " needs to be a list of words of type string.")
type_check(words, (list,), arg_name)
words_set = set()
for word in words:
if not isinstance(word, str):
raise ValueError("each word in " + arg_name + " needs to be type str.")
type_check(word, (str,), arg_name)
if word in words_set:
raise ValueError(arg_name + " contains duplicate word: " + word + ".")
words_set.add(word)
......@@ -45,21 +44,14 @@ def check_lookup(method):
@wraps(method)
def new_method(self, *args, **kwargs):
vocab, unknown = (list(args) + 2 * [None])[:2]
if "vocab" in kwargs:
vocab = kwargs.get("vocab")
if "unknown" in kwargs:
unknown = kwargs.get("unknown")
if unknown is not None:
if not (isinstance(unknown, int) and unknown >= 0):
raise ValueError("unknown needs to be a non-negative integer.")
[vocab, unknown], _ = parse_user_args(method, *args, **kwargs)
if not isinstance(vocab, cde.Vocab):
raise ValueError("vocab is not an instance of cde.Vocab.")
if unknown is not None:
type_check(unknown, (int,), "unknown")
check_positive(unknown)
type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.")
kwargs["vocab"] = vocab
kwargs["unknown"] = unknown
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -69,50 +61,15 @@ def check_from_file(method):
@wraps(method)
def new_method(self, *args, **kwargs):
file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5]
if "file_path" in kwargs:
file_path = kwargs.get("file_path")
if "delimiter" in kwargs:
delimiter = kwargs.get("delimiter")
if "vocab_size" in kwargs:
vocab_size = kwargs.get("vocab_size")
if "special_tokens" in kwargs:
special_tokens = kwargs.get("special_tokens")
if "special_first" in kwargs:
special_first = kwargs.get("special_first")
if not isinstance(file_path, str):
raise ValueError("file_path needs to be str.")
if delimiter is not None:
if not isinstance(delimiter, str):
raise ValueError("delimiter needs to be str.")
else:
delimiter = ""
if vocab_size is not None:
if not (isinstance(vocab_size, int) and vocab_size > 0):
raise ValueError("vocab size needs to be a positive integer.")
else:
vocab_size = -1
if special_first is None:
special_first = True
if not isinstance(special_first, bool):
raise ValueError("special_first needs to be a boolean value")
if special_tokens is None:
special_tokens = []
[file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args,
**kwargs)
check_unique_list_of_words(special_tokens, "special_tokens")
type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
if vocab_size is not None:
check_value(vocab_size, (-1, INT32_MAX), "vocab_size")
type_check(special_first, (bool,), special_first)
kwargs["file_path"] = file_path
kwargs["delimiter"] = delimiter
kwargs["vocab_size"] = vocab_size
kwargs["special_tokens"] = special_tokens
kwargs["special_first"] = special_first
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -122,33 +79,20 @@ def check_from_list(method):
@wraps(method)
def new_method(self, *args, **kwargs):
word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3]
if "word_list" in kwargs:
word_list = kwargs.get("word_list")
if "special_tokens" in kwargs:
special_tokens = kwargs.get("special_tokens")
if "special_first" in kwargs:
special_first = kwargs.get("special_first")
if special_tokens is None:
special_tokens = []
word_set = check_unique_list_of_words(word_list, "word_list")
token_set = check_unique_list_of_words(special_tokens, "special_tokens")
[word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)
intersect = word_set.intersection(token_set)
word_set = check_unique_list_of_words(word_list, "word_list")
if special_tokens is not None:
token_set = check_unique_list_of_words(special_tokens, "special_tokens")
if intersect != set():
raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
intersect = word_set.intersection(token_set)
if special_first is None:
special_first = True
if intersect != set():
raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
if not isinstance(special_first, bool):
raise ValueError("special_first needs to be a boolean value.")
type_check(special_first, (bool,), "special_first")
kwargs["word_list"] = word_list
kwargs["special_tokens"] = special_tokens
kwargs["special_first"] = special_first
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -158,18 +102,15 @@ def check_from_dict(method):
@wraps(method)
def new_method(self, *args, **kwargs):
word_dict, = (list(args) + [None])[:1]
if "word_dict" in kwargs:
word_dict = kwargs.get("word_dict")
if not isinstance(word_dict, dict):
raise ValueError("word_dict needs to be a list of word,id pairs.")
[word_dict], _ = parse_user_args(method, *args, **kwargs)
type_check(word_dict, (dict,), "word_dict")
for word, word_id in word_dict.items():
if not isinstance(word, str):
raise ValueError("Each word in word_dict needs to be type string.")
if not (isinstance(word_id, int) and word_id >= 0):
raise ValueError("Each word id needs to be positive integer.")
kwargs["word_dict"] = word_dict
return method(self, **kwargs)
type_check(word, (str,), "word")
type_check(word_id, (int,), "word_id")
check_value(word_id, (-1, INT32_MAX), "word_id")
return method(self, *args, **kwargs)
return new_method
......@@ -179,23 +120,8 @@ def check_jieba_init(method):
@wraps(method)
def new_method(self, *args, **kwargs):
hmm_path, mp_path, model = (list(args) + 3 * [None])[:3]
if "hmm_path" in kwargs:
hmm_path = kwargs.get("hmm_path")
if "mp_path" in kwargs:
mp_path = kwargs.get("mp_path")
if hmm_path is None:
raise ValueError(
"The dict of HMMSegment in cppjieba is not provided.")
kwargs["hmm_path"] = hmm_path
if mp_path is None:
raise ValueError(
"The dict of MPSegment in cppjieba is not provided.")
kwargs["mp_path"] = mp_path
if model is not None:
kwargs["model"] = model
return method(self, **kwargs)
parse_user_args(method, *args, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -205,19 +131,12 @@ def check_jieba_add_word(method):
@wraps(method)
def new_method(self, *args, **kwargs):
word, freq = (list(args) + 2 * [None])[:2]
if "word" in kwargs:
word = kwargs.get("word")
if "freq" in kwargs:
freq = kwargs.get("freq")
[word, freq], _ = parse_user_args(method, *args, **kwargs)
if word is None:
raise ValueError("word is not provided.")
kwargs["word"] = word
if freq is not None:
check_uint32(freq)
kwargs["freq"] = freq
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -227,13 +146,8 @@ def check_jieba_add_dict(method):
@wraps(method)
def new_method(self, *args, **kwargs):
user_dict = (list(args) + [None])[0]
if "user_dict" in kwargs:
user_dict = kwargs.get("user_dict")
if user_dict is None:
raise ValueError("user_dict is not provided.")
kwargs["user_dict"] = user_dict
return method(self, **kwargs)
parse_user_args(method, *args, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -244,69 +158,39 @@ def check_from_dataset(method):
@wraps(method)
def new_method(self, *args, **kwargs):
dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6]
if "dataset" in kwargs:
dataset = kwargs.get("dataset")
if "columns" in kwargs:
columns = kwargs.get("columns")
if "freq_range" in kwargs:
freq_range = kwargs.get("freq_range")
if "top_k" in kwargs:
top_k = kwargs.get("top_k")
if "special_tokens" in kwargs:
special_tokens = kwargs.get("special_tokens")
if "special_first" in kwargs:
special_first = kwargs.get("special_first")
if columns is None:
columns = []
if not isinstance(columns, list):
columns = [columns]
for column in columns:
if not isinstance(column, str):
raise ValueError("columns need to be a list of strings.")
if freq_range is None:
freq_range = (None, None)
if not isinstance(freq_range, tuple) or len(freq_range) != 2:
raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
[_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args,
**kwargs)
if columns is not None:
if not isinstance(columns, list):
columns = [columns]
col_names = ["col_{0}".format(i) for i in range(len(columns))]
type_check_list(columns, (str,), col_names)
for num in freq_range:
if num is not None and (not isinstance(num, int)):
raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
if freq_range is not None:
type_check(freq_range, (tuple,), "freq_range")
if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
if freq_range[0] > freq_range[1] or freq_range[0] < 0:
raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")
if len(freq_range) != 2:
raise ValueError("freq_range needs to be a tuple of 2 integers or an int and a None.")
if top_k is not None and (not isinstance(top_k, int)):
raise ValueError("top_k needs to be a positive integer.")
for num in freq_range:
if num is not None and (not isinstance(num, int)):
raise ValueError(
"freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
if isinstance(top_k, int) and top_k <= 0:
raise ValueError("top_k needs to be a positive integer.")
if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
if freq_range[0] > freq_range[1] or freq_range[0] < 0:
raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")
if special_first is None:
special_first = True
type_check(top_k, (int, type(None)), "top_k")
if special_tokens is None:
special_tokens = []
if isinstance(top_k, int):
check_value(top_k, (0, INT32_MAX), "top_k")
type_check(special_first, (bool,), "special_first")
if not isinstance(special_first, bool):
raise ValueError("special_first needs to be a boolean value.")
if special_tokens is not None:
check_unique_list_of_words(special_tokens, "special_tokens")
check_unique_list_of_words(special_tokens, "special_tokens")
kwargs["dataset"] = dataset
kwargs["columns"] = columns
kwargs["freq_range"] = freq_range
kwargs["top_k"] = top_k
kwargs["special_tokens"] = special_tokens
kwargs["special_first"] = special_first
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -316,15 +200,7 @@ def check_ngram(method):
@wraps(method)
def new_method(self, *args, **kwargs):
n, left_pad, right_pad, separator = (list(args) + 4 * [None])[:4]
if "n" in kwargs:
n = kwargs.get("n")
if "left_pad" in kwargs:
left_pad = kwargs.get("left_pad")
if "right_pad" in kwargs:
right_pad = kwargs.get("right_pad")
if "separator" in kwargs:
separator = kwargs.get("separator")
[n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs)
if isinstance(n, int):
n = [n]
......@@ -332,15 +208,9 @@ def check_ngram(method):
if not (isinstance(n, list) and n != []):
raise ValueError("n needs to be a non-empty list of positive integers.")
for gram in n:
if not (isinstance(gram, int) and gram > 0):
raise ValueError("n in ngram needs to be a positive number.")
if left_pad is None:
left_pad = ("", 0)
if right_pad is None:
right_pad = ("", 0)
for i, gram in enumerate(n):
type_check(gram, (int,), "gram[{0}]".format(i))
check_value(gram, (0, INT32_MAX), "gram_{}".format(i))
if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
left_pad[1], int)):
......@@ -353,11 +223,7 @@ def check_ngram(method):
if not (left_pad[1] >= 0 and right_pad[1] >= 0):
raise ValueError("padding width need to be positive numbers.")
if separator is None:
separator = " "
if not isinstance(separator, str):
raise ValueError("separator needs to be a string.")
type_check(separator, (str,), "separator")
kwargs["n"] = n
kwargs["left_pad"] = left_pad
......@@ -374,16 +240,8 @@ def check_pair_truncate(method):
@wraps(method)
def new_method(self, *args, **kwargs):
max_length = (list(args) + [None])[0]
if "max_length" in kwargs:
max_length = kwargs.get("max_length")
if max_length is None:
raise ValueError("max_length is not provided.")
check_pos_int64(max_length)
kwargs["max_length"] = max_length
return method(self, **kwargs)
parse_user_args(method, *args, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -393,22 +251,13 @@ def check_to_number(method):
@wraps(method)
def new_method(self, *args, **kwargs):
data_type = (list(args) + [None])[0]
if "data_type" in kwargs:
data_type = kwargs.get("data_type")
if data_type is None:
raise ValueError("data_type is a mandatory parameter but was not provided.")
if not isinstance(data_type, typing.Type):
raise TypeError("data_type is not a MindSpore data type.")
[data_type], _ = parse_user_args(method, *args, **kwargs)
type_check(data_type, (typing.Type,), "data_type")
if data_type not in mstype.number_type:
raise TypeError("data_type is not numeric data type.")
kwargs["data_type"] = data_type
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -418,18 +267,11 @@ def check_python_tokenizer(method):
@wraps(method)
def new_method(self, *args, **kwargs):
tokenizer = (list(args) + [None])[0]
if "tokenizer" in kwargs:
tokenizer = kwargs.get("tokenizer")
if tokenizer is None:
raise ValueError("tokenizer is a mandatory parameter.")
[tokenizer], _ = parse_user_args(method, *args, **kwargs)
if not callable(tokenizer):
raise TypeError("tokenizer is not a callable python function")
kwargs["tokenizer"] = tokenizer
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -18,6 +18,7 @@ from functools import wraps
import numpy as np
from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive
# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
......@@ -37,106 +38,33 @@ DOUBLE_MAX_INTEGER = 9007199254740992
DOUBLE_MIN_INTEGER = -9007199254740992
def check_type(value, valid_type):
if not isinstance(value, valid_type):
raise ValueError("Wrong input type")
def check_value(value, valid_range):
if value < valid_range[0] or value > valid_range[1]:
raise ValueError("Input is not within the required range")
def check_range(values, valid_range):
if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
raise ValueError("Input range is not valid")
def check_positive(value):
if value <= 0:
raise ValueError("Input must greater than 0")
def check_positive_float(value, valid_max=None):
if value <= 0 or not isinstance(value, float) or (valid_max is not None and value > valid_max):
raise ValueError("Input need to be a valid positive float.")
def check_bool(value):
if not isinstance(value, bool):
raise ValueError("Value needs to be a boolean.")
def check_2tuple(value):
if not (isinstance(value, tuple) and len(value) == 2):
raise ValueError("Value needs to be a 2-tuple.")
def check_list(value):
if not isinstance(value, list):
raise ValueError("The input needs to be a list.")
def check_uint8(value):
if not isinstance(value, int):
raise ValueError("The input needs to be a integer")
check_value(value, [UINT8_MIN, UINT8_MAX])
def check_uint32(value):
if not isinstance(value, int):
raise ValueError("The input needs to be a integer")
check_value(value, [UINT32_MIN, UINT32_MAX])
def check_pos_int32(value):
"""Checks for int values starting from 1"""
if not isinstance(value, int):
raise ValueError("The input needs to be a integer")
check_value(value, [POS_INT_MIN, INT32_MAX])
def check_uint64(value):
if not isinstance(value, int):
raise ValueError("The input needs to be a integer")
check_value(value, [UINT64_MIN, UINT64_MAX])
def check_pos_int64(value):
if not isinstance(value, int):
raise ValueError("The input needs to be a integer")
check_value(value, [UINT64_MIN, INT64_MAX])
def check_fill_value(method):
"""Wrapper method to check the parameters of fill_value."""
def check_pos_float32(value):
check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER])
@wraps(method)
def new_method(self, *args, **kwargs):
[fill_value], _ = parse_user_args(method, *args, **kwargs)
type_check(fill_value, (str, float, bool, int, bytes), "fill_value")
return method(self, *args, **kwargs)
def check_pos_float64(value):
check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER])
return new_method
def check_one_hot_op(method):
"""Wrapper method to check the parameters of one hot op."""
"""Wrapper method to check the parameters of one_hot_op."""
@wraps(method)
def new_method(self, *args, **kwargs):
args = (list(args) + 2 * [None])[:2]
num_classes, smoothing_rate = args
if "num_classes" in kwargs:
num_classes = kwargs.get("num_classes")
if "smoothing_rate" in kwargs:
smoothing_rate = kwargs.get("smoothing_rate")
if num_classes is None:
raise ValueError("num_classes")
check_pos_int32(num_classes)
kwargs["num_classes"] = num_classes
[num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs)
type_check(num_classes, (int,), "num_classes")
check_positive(num_classes)
if smoothing_rate is not None:
check_value(smoothing_rate, [0., 1.])
kwargs["smoothing_rate"] = smoothing_rate
check_value(smoothing_rate, [0., 1.], "smoothing_rate")
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -146,35 +74,12 @@ def check_num_classes(method):
@wraps(method)
def new_method(self, *args, **kwargs):
num_classes = (list(args) + [None])[0]
if "num_classes" in kwargs:
num_classes = kwargs.get("num_classes")
if num_classes is None:
raise ValueError("num_classes is not provided.")
check_pos_int32(num_classes)
kwargs["num_classes"] = num_classes
return method(self, **kwargs)
return new_method
[num_classes], _ = parse_user_args(method, *args, **kwargs)
def check_fill_value(method):
"""Wrapper method to check the parameters of fill value."""
@wraps(method)
def new_method(self, *args, **kwargs):
fill_value = (list(args) + [None])[0]
if "fill_value" in kwargs:
fill_value = kwargs.get("fill_value")
if fill_value is None:
raise ValueError("fill_value is not provided.")
if not isinstance(fill_value, (str, float, bool, int, bytes)):
raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int")
kwargs["fill_value"] = fill_value
type_check(num_classes, (int,), "num_classes")
check_positive(num_classes)
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -184,17 +89,11 @@ def check_de_type(method):
@wraps(method)
def new_method(self, *args, **kwargs):
data_type = (list(args) + [None])[0]
if "data_type" in kwargs:
data_type = kwargs.get("data_type")
[data_type], _ = parse_user_args(method, *args, **kwargs)
if data_type is None:
raise ValueError("data_type is not provided.")
if not isinstance(data_type, typing.Type):
raise TypeError("data_type is not a MindSpore data type.")
kwargs["data_type"] = data_type
type_check(data_type, (typing.Type,), "data_type")
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -204,13 +103,11 @@ def check_slice_op(method):
@wraps(method)
def new_method(self, *args):
for i, arg in enumerate(args):
if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)):
raise TypeError("Indexing of dim " + str(i) + "is not of valid type")
for _, arg in enumerate(args):
type_check(arg, (int, slice, list, type(None), type(Ellipsis)), "arg")
if isinstance(arg, list):
for a in arg:
if not isinstance(a, int):
raise TypeError("Index " + a + " is not an int")
type_check(a, (int,), "a")
return method(self, *args)
return new_method
......@@ -221,36 +118,14 @@ def check_mask_op(method):
@wraps(method)
def new_method(self, *args, **kwargs):
operator, constant, dtype = (list(args) + 3 * [None])[:3]
if "operator" in kwargs:
operator = kwargs.get("operator")
if "constant" in kwargs:
constant = kwargs.get("constant")
if "dtype" in kwargs:
dtype = kwargs.get("dtype")
if operator is None:
raise ValueError("operator is not provided.")
if constant is None:
raise ValueError("constant is not provided.")
[operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs)
from .c_transforms import Relational
if not isinstance(operator, Relational):
raise TypeError("operator is not a Relational operator enum.")
type_check(operator, (Relational,), "operator")
type_check(constant, (str, float, bool, int, bytes), "constant")
type_check(dtype, (typing.Type,), "dtype")
if not isinstance(constant, (str, float, bool, int, bytes)):
raise TypeError("constant must be either a primitive python str, float, bool, bytes or int")
if dtype is not None:
if not isinstance(dtype, typing.Type):
raise TypeError("dtype is not a MindSpore data type.")
kwargs["dtype"] = dtype
kwargs["operator"] = operator
kwargs["constant"] = constant
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -260,22 +135,12 @@ def check_pad_end(method):
@wraps(method)
def new_method(self, *args, **kwargs):
pad_shape, pad_value = (list(args) + 2 * [None])[:2]
if "pad_shape" in kwargs:
pad_shape = kwargs.get("pad_shape")
if "pad_value" in kwargs:
pad_value = kwargs.get("pad_value")
if pad_shape is None:
raise ValueError("pad_shape is not provided.")
[pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs)
if pad_value is not None:
if not isinstance(pad_value, (str, float, bool, int, bytes)):
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes")
kwargs["pad_value"] = pad_value
if not isinstance(pad_shape, list):
raise TypeError("pad_shape must be a list")
type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
type_check(pad_shape, (list,), "pad_end")
for dim in pad_shape:
if dim is not None:
......@@ -284,9 +149,7 @@ def check_pad_end(method):
else:
raise TypeError("a value in the list is not an integer.")
kwargs["pad_shape"] = pad_shape
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -296,31 +159,24 @@ def check_concat_type(method):
@wraps(method)
def new_method(self, *args, **kwargs):
axis, prepend, append = (list(args) + 3 * [None])[:3]
if "prepend" in kwargs:
prepend = kwargs.get("prepend")
if "append" in kwargs:
append = kwargs.get("append")
if "axis" in kwargs:
axis = kwargs.get("axis")
[axis, prepend, append], _ = parse_user_args(method, *args, **kwargs)
if axis is not None:
if not isinstance(axis, int):
raise TypeError("axis type is not valid, must be an integer.")
type_check(axis, (int,), "axis")
if axis not in (0, -1):
raise ValueError("only 1D concatenation supported.")
kwargs["axis"] = axis
if prepend is not None:
if not isinstance(prepend, (type(None), np.ndarray)):
raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.")
kwargs["prepend"] = prepend
type_check(prepend, (np.ndarray,), "prepend")
if len(prepend.shape) != 1:
raise ValueError("can only prepend 1D arrays.")
if append is not None:
if not isinstance(append, (type(None), np.ndarray)):
raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.")
kwargs["append"] = append
type_check(append, (np.ndarray,), "append")
if len(append.shape) != 1:
raise ValueError("can only append 1D arrays.")
return method(self, **kwargs)
return method(self, *args, **kwargs)
return new_method
......@@ -40,12 +40,14 @@ Examples:
>>> dataset = dataset.map(input_columns="image", operations=transforms_list)
>>> dataset = dataset.map(input_columns="label", operations=onehot_op)
"""
import numbers
import mindspore._c_dataengine as cde
from .utils import Inter, Border
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \
FLOAT_MAX_INTEGER
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
......@@ -57,6 +59,18 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT,
Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC}
def parse_padding(padding):
if isinstance(padding, numbers.Number):
padding = [padding] * 4
if len(padding) == 2:
left = right = padding[0]
top = bottom = padding[1]
padding = (left, top, right, bottom,)
if isinstance(padding, list):
padding = tuple(padding)
return padding
class Decode(cde.DecodeOp):
"""
Decode the input image in RGB mode.
......@@ -136,16 +150,22 @@ class RandomCrop(cde.RandomCropOp):
@check_random_crop
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill_value = fill_value
self.padding_mode = padding_mode.value
if isinstance(size, int):
size = (size, size)
if padding is None:
padding = (0, 0, 0, 0)
else:
padding = parse_padding(padding)
if isinstance(fill_value, int): # temporary fix
fill_value = tuple([fill_value] * 3)
border_type = DE_C_BORDER_TYPE[padding_mode]
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill_value = fill_value
self.padding_mode = padding_mode.value
super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value)
......@@ -184,16 +204,23 @@ class RandomCropWithBBox(cde.RandomCropWithBBoxOp):
@check_random_crop
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill_value = fill_value
self.padding_mode = padding_mode.value
if isinstance(size, int):
size = (size, size)
if padding is None:
padding = (0, 0, 0, 0)
else:
padding = parse_padding(padding)
if isinstance(fill_value, int): # temporary fix
fill_value = tuple([fill_value] * 3)
border_type = DE_C_BORDER_TYPE[padding_mode]
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill_value = fill_value
self.padding_mode = padding_mode.value
super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value)
......@@ -292,6 +319,8 @@ class Resize(cde.ResizeOp):
@check_resize_interpolation
def __init__(self, size, interpolation=Inter.LINEAR):
if isinstance(size, int):
size = (size, size)
self.size = size
self.interpolation = interpolation
interpoltn = DE_C_INTER_MODE[interpolation]
......@@ -359,6 +388,8 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp):
@check_random_resize_crop
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation=Inter.BILINEAR, max_attempts=10):
if isinstance(size, int):
size = (size, size)
self.size = size
self.scale = scale
self.ratio = ratio
......@@ -396,6 +427,8 @@ class RandomResizedCrop(cde.RandomCropAndResizeOp):
@check_random_resize_crop
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation=Inter.BILINEAR, max_attempts=10):
if isinstance(size, int):
size = (size, size)
self.size = size
self.scale = scale
self.ratio = ratio
......@@ -417,6 +450,8 @@ class CenterCrop(cde.CenterCropOp):
@check_crop
def __init__(self, size):
if isinstance(size, int):
size = (size, size)
self.size = size
super().__init__(*size)
......@@ -442,12 +477,26 @@ class RandomColorAdjust(cde.RandomColorAdjustOp):
@check_random_color_adjust
def __init__(self, brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)):
brightness = self.expand_values(brightness)
contrast = self.expand_values(contrast)
saturation = self.expand_values(saturation)
hue = self.expand_values(hue, center=0, bound=(-0.5, 0.5), non_negative=False)
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
super().__init__(*brightness, *contrast, *saturation, *hue)
def expand_values(self, value, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
if isinstance(value, numbers.Number):
value = [center - value, center + value]
if non_negative:
value[0] = max(0, value[0])
check_range(value, bound)
return (value[0], value[1])
class RandomRotation(cde.RandomRotationOp):
"""
......@@ -485,6 +534,8 @@ class RandomRotation(cde.RandomRotationOp):
self.expand = expand
self.center = center
self.fill_value = fill_value
if isinstance(degrees, numbers.Number):
degrees = (-degrees, degrees)
if center is None:
center = (-1, -1)
if isinstance(fill_value, int): # temporary fix
......@@ -584,6 +635,8 @@ class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp):
@check_random_resize_crop
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation=Inter.BILINEAR, max_attempts=10):
if isinstance(size, int):
size = (size, size)
self.size = size
self.scale = scale
self.ratio = ratio
......@@ -623,12 +676,14 @@ class Pad(cde.PadOp):
@check_pad
def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT):
self.padding = padding
self.fill_value = fill_value
self.padding_mode = padding_mode
padding = parse_padding(padding)
if isinstance(fill_value, int): # temporary fix
fill_value = tuple([fill_value] * 3)
padding_mode = DE_C_BORDER_TYPE[padding_mode]
self.padding = padding
self.fill_value = fill_value
self.padding_mode = padding_mode
super().__init__(*padding, padding_mode, *fill_value)
......
......@@ -28,6 +28,7 @@ import numpy as np
from PIL import Image
from . import py_transforms_util as util
from .c_transforms import parse_padding
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \
......@@ -295,6 +296,10 @@ class RandomCrop:
@check_random_crop
def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT):
if padding is None:
padding = (0, 0, 0, 0)
else:
padding = parse_padding(padding)
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
......@@ -753,6 +758,8 @@ class TenCrop:
@check_ten_crop
def __init__(self, size, use_vertical_flip=False):
if isinstance(size, int):
size = (size, size)
self.size = size
self.use_vertical_flip = use_vertical_flip
......@@ -877,6 +884,8 @@ class Pad:
@check_pad
def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT):
parse_padding(padding)
self.padding = padding
self.fill_value = fill_value
self.padding_mode = DE_PY_BORDER_TYPE[padding_mode]
......@@ -1129,56 +1138,23 @@ class RandomAffine:
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0):
# Parameter checking
# rotation
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
self.degrees = degrees
else:
raise TypeError("If degrees is a list or tuple, it must be of length 2.")
# translation
if translate is not None:
if isinstance(translate, (tuple, list)) and len(translate) == 2:
for t in translate:
if t < 0.0 or t > 1.0:
raise ValueError("translation values should be between 0 and 1")
else:
raise TypeError("translate should be a list or tuple of length 2.")
self.translate = translate
# scale
if scale is not None:
if isinstance(scale, (tuple, list)) and len(scale) == 2:
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
else:
raise TypeError("scale should be a list or tuple of length 2.")
self.scale_ranges = scale
# shear
if shear is not None:
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError("If shear is a single number, it must be positive.")
self.shear = (-1 * shear, shear)
elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4):
# X-Axis shear with [min, max]
shear = (-1 * shear, shear)
else:
if len(shear) == 2:
self.shear = [shear[0], shear[1], 0., 0.]
shear = [shear[0], shear[1], 0., 0.]
elif len(shear) == 4:
self.shear = [s for s in shear]
else:
raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.")
else:
self.shear = shear
shear = [s for s in shear]
# resample
self.resample = DE_PY_INTER_MODE[resample]
if isinstance(degrees, numbers.Number):
degrees = (-degrees, degrees)
# fill_value
self.degrees = degrees
self.translate = translate
self.scale_ranges = scale
self.shear = shear
self.resample = DE_PY_INTER_MODE[resample]
self.fill_value = fill_value
def __call__(self, img):
......
......@@ -15,13 +15,15 @@
"""
Testing the bounding box augment op in DE
"""
from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \
config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5
import numpy as np
import mindspore.log as logger
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \
config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5
GENERATE_GOLDEN = False
# updated VOC dataset with correct annotations
......@@ -241,7 +243,7 @@ def test_bounding_box_augment_invalid_ratio_c():
operations=[test_op]) # Add column for "annotation"
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Input is not" in str(error)
assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error)
def test_bounding_box_augment_invalid_bounds_c():
......
......@@ -17,6 +17,7 @@ import pytest
import numpy as np
import mindspore.dataset as ds
# generates 1 column [0], [0, 1], ..., [0, ..., n-1]
def generate_sequential(n):
for i in range(n):
......@@ -99,12 +100,12 @@ def test_bucket_batch_invalid_input():
with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
None, None, invalid_type_pad_to_bucket_boundary)
assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value)
assert "Argument pad_to_bucket_boundary with value \"\" is not of type (<class \'bool\'>,)." in str(info.value)
with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
None, None, False, invalid_type_drop_remainder)
assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value)
assert "Argument drop_remainder with value \"\" is not of type (<class 'bool'>,)." in str(info.value)
def test_bucket_batch_multi_bucket_no_padding():
......@@ -272,7 +273,6 @@ def test_bucket_batch_default_pad():
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]]
output = []
for data in dataset.create_dict_iterator():
output.append(data["col1"].tolist())
......
......@@ -163,18 +163,11 @@ def test_concatenate_op_negative_axis():
def test_concatenate_op_incorrect_input_dim():
def gen():
yield (np.array(["ss", "ad"], dtype='S'),)
prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S')
data = ds.GeneratorDataset(gen, column_names=["col"])
concatenate_op = data_trans.Concatenate(0, prepend_tensor)
data = data.map(input_columns=["col"], operations=concatenate_op)
with pytest.raises(RuntimeError) as error_info:
for _ in data:
pass
assert "Only 1D tensors supported" in repr(error_info.value)
with pytest.raises(ValueError) as error_info:
data_trans.Concatenate(0, prepend_tensor)
assert "can only prepend 1D arrays." in repr(error_info.value)
if __name__ == "__main__":
......
......@@ -28,9 +28,9 @@ def test_exception_01():
"""
logger.info("test_exception_01")
data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"])
with pytest.raises(ValueError) as info:
data = data.map(input_columns=["image"], operations=vision.Resize(100, 100))
assert "Invalid interpolation mode." in str(info.value)
with pytest.raises(TypeError) as info:
data.map(input_columns=["image"], operations=vision.Resize(100, 100))
assert "Argument interpolation with value 100 is not of type (<enum 'Inter'>,)" in str(info.value)
def test_exception_02():
......@@ -40,8 +40,8 @@ def test_exception_02():
logger.info("test_exception_02")
num_samples = -1
with pytest.raises(ValueError) as info:
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
assert "num_samples cannot be less than 0" in str(info.value)
ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(info.value)
num_samples = 1
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
......
......@@ -23,7 +23,8 @@ import mindspore.dataset.text as text
def test_demo_basic_from_dataset():
""" this is a tutorial on how from_dataset should be used in a normal use case"""
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None,
special_tokens=["<pad>", "<unk>"],
special_first=True)
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
res = []
......@@ -127,15 +128,16 @@ def test_from_dataset_exceptions():
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
assert isinstance(vocab.text.Vocab)
except ValueError as e:
except (TypeError, ValueError, RuntimeError) as e:
assert s in str(e), str(e)
test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers")
test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer")
test_config(23, (2, 3), 1.2345, "columns need to be a list of strings")
test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b")
test_config("text", (2, 3), 0, "top_k needs to be a positive integer")
test_config([123], (2, 3), 0, "columns need to be a list of strings")
test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.")
test_config("text", (2, 3), 1.2345,
"Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)")
test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)")
test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")
test_config("text", (2, 3), 0, "top_k needs to be positive number")
test_config([123], (2, 3), 0, "top_k needs to be positive number")
if __name__ == '__main__':
......
......@@ -73,6 +73,7 @@ def test_linear_transformation_op(plot=False):
if plot:
visualize_list(image, image_transformed)
def test_linear_transformation_md5():
"""
Test LinearTransformation op: valid params (transformation_matrix, mean_vector)
......@@ -102,6 +103,7 @@ def test_linear_transformation_md5():
filename = "linear_transformation_01_result.npz"
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_linear_transformation_exception_01():
"""
Test LinearTransformation op: transformation_matrix is not provided
......@@ -126,9 +128,10 @@ def test_linear_transformation_exception_01():
]
transform = py_vision.ComposeOp(transforms)
data1 = data1.map(input_columns=["image"], operations=transform())
except ValueError as e:
except TypeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "not provided" in str(e)
assert "Argument transformation_matrix with value None is not of type (<class 'numpy.ndarray'>,)" in str(e)
def test_linear_transformation_exception_02():
"""
......@@ -154,9 +157,10 @@ def test_linear_transformation_exception_02():
]
transform = py_vision.ComposeOp(transforms)
data1 = data1.map(input_columns=["image"], operations=transform())
except ValueError as e:
except TypeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "not provided" in str(e)
assert "Argument mean_vector with value None is not of type (<class 'numpy.ndarray'>,)" in str(e)
def test_linear_transformation_exception_03():
"""
......@@ -187,6 +191,7 @@ def test_linear_transformation_exception_03():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "square matrix" in str(e)
def test_linear_transformation_exception_04():
"""
Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix
......@@ -199,7 +204,7 @@ def test_linear_transformation_exception_04():
weight = 50
dim = 3 * height * weight
transformation_matrix = np.ones([dim, dim])
mean_vector = np.zeros(dim-1)
mean_vector = np.zeros(dim - 1)
# Generate dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
......@@ -216,6 +221,7 @@ def test_linear_transformation_exception_04():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "should match" in str(e)
if __name__ == '__main__':
test_linear_transformation_op(plot=True)
test_linear_transformation_md5()
......
......@@ -184,24 +184,26 @@ def test_minddataset_invalidate_num_shards():
create_cv_mindrecord(1)
columns_list = ["data", "label"]
num_readers = 4
with pytest.raises(Exception, match="shard_id is invalid, "):
with pytest.raises(Exception) as error_info:
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info)
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
def test_minddataset_invalidate_shard_id():
create_cv_mindrecord(1)
columns_list = ["data", "label"]
num_readers = 4
with pytest.raises(Exception, match="shard_id is invalid, "):
with pytest.raises(Exception) as error_info:
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info)
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
......@@ -210,17 +212,19 @@ def test_minddataset_shard_id_bigger_than_num_shard():
create_cv_mindrecord(1)
columns_list = ["data", "label"]
num_readers = 4
with pytest.raises(Exception, match="shard_id is invalid, "):
with pytest.raises(Exception) as error_info:
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info)
with pytest.raises(Exception, match="shard_id is invalid, "):
with pytest.raises(Exception) as error_info:
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info)
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
......@@ -15,9 +15,9 @@
"""
Testing Ngram in mindspore.dataset
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
import numpy as np
def test_multiple_ngrams():
......@@ -61,7 +61,7 @@ def test_simple_ngram():
yield (np.array(line.split(" "), dtype='S'),)
dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=None))
dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=" "))
i = 0
for data in dataset.create_dict_iterator():
......@@ -72,7 +72,7 @@ def test_simple_ngram():
def test_corner_cases():
""" testing various corner cases and exceptions"""
def test_config(input_line, output_line, n, l_pad=None, r_pad=None, sep=None):
def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
def gen(texts):
yield (np.array(texts.split(" "), dtype='S'),)
......@@ -93,7 +93,7 @@ def test_corner_cases():
try:
test_config("Yours to Discover", "", [0, [1]])
except Exception as e:
assert "ngram needs to be a positive number" in str(e)
assert "Argument gram[1] with value [1] is not of type (<class 'int'>,)" in str(e)
# test empty n
try:
test_config("Yours to Discover", "", [])
......
......@@ -279,7 +279,7 @@ def test_normalize_exception_invalid_range_py():
_ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32])
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input is not within the required range" in str(e)
assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e)
def test_normalize_grayscale_md5_01():
......
......@@ -61,6 +61,10 @@ def test_pad_end_exceptions():
pad_compare([3, 4, 5], ["2"], 1, [])
assert "a value in the list is not an integer." in str(info.value)
with pytest.raises(TypeError) as info:
pad_compare([1, 2], 3, -1, [1, 2, -1])
assert "Argument pad_end with value 3 is not of type (<class 'list'>,)" in str(info.value)
if __name__ == "__main__":
test_pad_end_basics()
......
......@@ -103,7 +103,7 @@ def test_random_affine_exception_negative_degrees():
_ = py_vision.RandomAffine(degrees=-15)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "If degrees is a single number, it cannot be negative."
assert str(e) == "Input degrees is not within the required interval of (0 to inf)."
def test_random_affine_exception_translation_range():
......@@ -115,7 +115,7 @@ def test_random_affine_exception_translation_range():
_ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5))
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "translation values should be between 0 and 1"
assert str(e) == "Input translate at 1 is not within the required interval of (0.0 to 1.0)."
def test_random_affine_exception_scale_value():
......@@ -127,7 +127,7 @@ def test_random_affine_exception_scale_value():
_ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1))
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "scale values should be positive"
assert str(e) == "Input scale[0] must be greater than 0."
def test_random_affine_exception_shear_value():
......@@ -139,7 +139,7 @@ def test_random_affine_exception_shear_value():
_ = py_vision.RandomAffine(degrees=15, shear=-5)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "If shear is a single number, it must be positive."
assert str(e) == "Input shear must be greater than 0."
def test_random_affine_exception_degrees_size():
......@@ -165,7 +165,9 @@ def test_random_affine_exception_translate_size():
_ = py_vision.RandomAffine(degrees=15, translate=(0.1))
except TypeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "translate should be a list or tuple of length 2."
assert str(
e) == "Argument translate with value 0.1 is not of type (<class 'list'>," \
" <class 'tuple'>)."
def test_random_affine_exception_scale_size():
......@@ -178,7 +180,8 @@ def test_random_affine_exception_scale_size():
_ = py_vision.RandomAffine(degrees=15, scale=(0.5))
except TypeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "scale should be a list or tuple of length 2."
assert str(e) == "Argument scale with value 0.5 is not of type (<class 'tuple'>," \
" <class 'list'>)."
def test_random_affine_exception_shear_size():
......@@ -191,7 +194,7 @@ def test_random_affine_exception_shear_size():
_ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10))
except TypeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "shear should be a list or tuple and it must be of length 2 or 4."
assert str(e) == "shear must be of length 2 or 4."
if __name__ == "__main__":
......
......@@ -97,7 +97,7 @@ def test_random_color_md5():
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms = F.ComposeOp([F.Decode(),
F.RandomColor((0.5, 1.5)),
F.RandomColor((0.1, 1.9)),
F.ToTensor()])
data = data.map(input_columns="image", operations=transforms())
......
......@@ -232,7 +232,7 @@ def test_random_crop_and_resize_04_c():
data = data.map(input_columns=["image"], operations=random_crop_and_resize_op)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input range is not valid" in str(e)
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
def test_random_crop_and_resize_04_py():
......@@ -255,7 +255,7 @@ def test_random_crop_and_resize_04_py():
data = data.map(input_columns=["image"], operations=transform())
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input range is not valid" in str(e)
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
def test_random_crop_and_resize_05_c():
......@@ -275,7 +275,7 @@ def test_random_crop_and_resize_05_c():
data = data.map(input_columns=["image"], operations=random_crop_and_resize_op)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input range is not valid" in str(e)
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
def test_random_crop_and_resize_05_py():
......@@ -298,7 +298,7 @@ def test_random_crop_and_resize_05_py():
data = data.map(input_columns=["image"], operations=transform())
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input range is not valid" in str(e)
assert "Input is not within the required interval of (0 to 16777216)." in str(e)
def test_random_crop_and_resize_comp(plot=False):
......
......@@ -159,7 +159,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c():
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input range is not valid" in str(err)
assert "Input is not within the required interval of (0 to 16777216)." in str(err)
def test_random_resized_crop_with_bbox_op_invalid2_c():
......@@ -185,7 +185,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c():
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input range is not valid" in str(err)
assert "Input is not within the required interval of (0 to 16777216)." in str(err)
def test_random_resized_crop_with_bbox_op_bad_c():
......
......@@ -179,7 +179,7 @@ def test_random_grayscale_invalid_param():
data = data.map(input_columns=["image"], operations=transform())
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input is not within the required range" in str(e)
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
if __name__ == "__main__":
test_random_grayscale_valid_prob(True)
......
......@@ -141,7 +141,7 @@ def test_random_horizontal_invalid_prob_c():
data = data.map(input_columns=["image"], operations=random_horizontal_op)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input is not" in str(e)
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
def test_random_horizontal_invalid_prob_py():
......@@ -164,7 +164,7 @@ def test_random_horizontal_invalid_prob_py():
data = data.map(input_columns=["image"], operations=transform())
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input is not" in str(e)
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
def test_random_horizontal_comp(plot=False):
......
......@@ -190,7 +190,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c():
operations=[test_op]) # Add column for "annotation"
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Input is not" in str(error)
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error)
def test_random_horizontal_flip_with_bbox_invalid_bounds_c():
......
......@@ -107,7 +107,7 @@ def test_random_perspective_exception_distortion_scale_range():
_ = py_vision.RandomPerspective(distortion_scale=1.5)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input is not within the required range"
assert str(e) == "Input distortion_scale is not within the required interval of (0.0 to 1.0)."
def test_random_perspective_exception_prob_range():
......@@ -119,7 +119,7 @@ def test_random_perspective_exception_prob_range():
_ = py_vision.RandomPerspective(prob=1.2)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input is not within the required range"
assert str(e) == "Input prob is not within the required interval of (0.0 to 1.0)."
if __name__ == "__main__":
......
......@@ -163,7 +163,7 @@ def test_random_resize_with_bbox_op_invalid_c():
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input is not" in str(err)
assert "Input is not within the required interval of (1 to 16777216)." in str(err)
try:
# one of the size values is zero
......@@ -171,7 +171,7 @@ def test_random_resize_with_bbox_op_invalid_c():
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input is not" in str(err)
assert "Input size at dim 0 is not within the required interval of (1 to 2147483647)." in str(err)
try:
# negative value for resize
......@@ -179,7 +179,7 @@ def test_random_resize_with_bbox_op_invalid_c():
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input is not" in str(err)
assert "Input is not within the required interval of (1 to 16777216)." in str(err)
try:
# invalid input shape
......
......@@ -97,7 +97,7 @@ def test_random_sharpness_md5():
# define map operations
transforms = [
F.Decode(),
F.RandomSharpness((0.5, 1.5)),
F.RandomSharpness((0.1, 1.9)),
F.ToTensor()
]
transform = F.ComposeOp(transforms)
......
......@@ -141,7 +141,7 @@ def test_random_vertical_invalid_prob_c():
data = data.map(input_columns=["image"], operations=random_horizontal_op)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input is not" in str(e)
assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e)
def test_random_vertical_invalid_prob_py():
......@@ -163,7 +163,7 @@ def test_random_vertical_invalid_prob_py():
data = data.map(input_columns=["image"], operations=transform())
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input is not" in str(e)
assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e)
def test_random_vertical_comp(plot=False):
......
......@@ -191,7 +191,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input is not" in str(err)
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(err)
def test_random_vertical_flip_with_bbox_op_bad_c():
......
......@@ -150,7 +150,7 @@ def test_resize_with_bbox_op_invalid_c():
# invalid interpolation value
c_vision.ResizeWithBBox(400, interpolation="invalid")
except ValueError as err:
except TypeError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "interpolation" in str(err)
......
......@@ -154,7 +154,7 @@ def test_shuffle_exception_01():
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e)
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
def test_shuffle_exception_02():
......@@ -172,7 +172,7 @@ def test_shuffle_exception_02():
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e)
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
def test_shuffle_exception_03():
......@@ -190,7 +190,7 @@ def test_shuffle_exception_03():
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e)
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
def test_shuffle_exception_05():
......
......@@ -62,7 +62,7 @@ def util_test_ten_crop(crop_size, vertical_flip=False, plot=False):
logger.info("dtype of image_2: {}".format(image_2.dtype))
if plot:
visualize_list(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))
visualize_list(np.array([image_1] * 10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))
# The output data should be of a 4D tensor shape, a stack of 10 images.
assert len(image_2.shape) == 4
......@@ -144,7 +144,7 @@ def test_ten_crop_invalid_size_error_msg():
vision.TenCrop(0),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
]
error_msg = "Input is not within the required range"
error_msg = "Input is not within the required interval of (1 to 16777216)."
assert error_msg == str(info.value)
with pytest.raises(ValueError) as info:
......
......@@ -169,7 +169,9 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "operations" in str(e)
assert "Argument tensor_op_5 with value" \
" <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e)
assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e)
def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
......@@ -209,7 +211,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)
assert "Input num_ops must be greater than 0" in str(e)
def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
......@@ -229,7 +231,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "integer" in str(e)
assert "Argument num_ops with value 2.5 is not of type (<class 'int'>,)" in str(e)
def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
......
......@@ -314,14 +314,15 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=
if len(orig) != len(aug) or not orig:
return
batch_size = int(len(orig)/plot_rows) # creates batches of images to plot together
batch_size = int(len(orig) / plot_rows) # creates batches of images to plot together
split_point = batch_size * plot_rows
orig, aug = np.array(orig), np.array(aug)
if len(orig) > plot_rows:
# Create batches of required size and add remainder to last batch
orig = np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added
orig = np.split(orig[:split_point], batch_size) + (
[orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added
aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
else:
orig = [orig]
......@@ -336,7 +337,8 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=
for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
cur_ix = base_ix + x
(axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row
# select plotting axes based on number of image rows on plot - else case when 1 row
(axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1])
axA.imshow(dataA["image"])
add_bounding_boxes(axA, dataA[annot_name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册