diff --git a/mindspore/dataset/transforms/vision/py_transforms.py b/mindspore/dataset/transforms/vision/py_transforms.py index f85407c1feec10a255d374f66f1460b3bfccf029..ee5a4b09fdae46a661e96efacd8736d66b009f68 100644 --- a/mindspore/dataset/transforms/vision/py_transforms.py +++ b/mindspore/dataset/transforms/vision/py_transforms.py @@ -32,7 +32,7 @@ from .validators import check_prob, check_crop, check_resize_interpolation, chec check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \ check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \ - check_mix_up, check_positive_degrees, check_uniform_augment_py + check_mix_up, check_positive_degrees, check_uniform_augment_py, check_compose_list from .utils import Inter, Border DE_PY_INTER_MODE = {Inter.NEAREST: Image.NEAREST, @@ -75,6 +75,7 @@ class ComposeOp: >>> dataset = dataset.map(input_columns="image", operations=transform()) """ + @check_compose_list def __init__(self, transforms): self.transforms = transforms diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index d6fbc4af16c7d8ca4b5e535055c694f0b829f877..20239232b5c203e6883ca8321c33f30345c36dda 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -907,3 +907,25 @@ def check_positive_degrees(method): return method(self, **kwargs) return new_method + + +def check_compose_list(method): + """Wrapper method to check the transform list of ComposeOp.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + transforms = (list(args) + [None])[0] + if "transforms" in kwargs: + transforms = kwargs.get("transforms") + if transforms is None: + raise ValueError("transforms is not provided.") + if not transforms: + raise ValueError("transforms list is empty.") + if not isinstance(transforms, list): + raise TypeError("transforms is not a python list") + + kwargs["transforms"] = transforms + + return method(self, **kwargs) + + return new_method