diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 4b020040867705fd20e99642691d4ac8ade2e14d..683102aecfca130f38009c705a5e71ef64a6c011 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2464,47 +2464,53 @@ class Schema: Parse the columns and add it to self. Args: - columns (dict or list[str]): names of columns. + columns (dict or list[dict]): dataset attribution information, decoded from schema file. + if list: columns element must be dict, 'name' and 'type' must be in keys, 'shape' optional. + if dict: columns.keys() as name, element in columns.values() is dict, and 'type' inside, 'shape' optional. + example 1) + [{'name': 'image', 'type': 'int8', 'shape': [3, 3]}, + {'name': 'label', 'type': 'int8', 'shape': [1]}] + example 2) + {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}} Raises: - RuntimeError: If failed to parse schema file. - RuntimeError: If unknown items in schema file. + RuntimeError: If failed to parse columns. + RuntimeError: If unknown items in columns. RuntimeError: If column's name field is missing. RuntimeError: If column's type field is missing. """ - if columns is None: - raise TypeError("Expected non-empty dict or string list.") self.columns = [] - for col in columns: - name = None - shape = None - data_type = None - col_details = None - if isinstance(columns, list): - col_details = col - if "name" in col: - name = col["name"] - elif isinstance(columns, dict): - col_details = columns[col] - name = col - else: - raise RuntimeError("Error parsing the schema file") - - for k, v in col_details.items(): - if k == "shape": - shape = v - elif k == "type": - data_type = v - elif k in ("t_impl", "rank"): - pass - else: - raise RuntimeError("Unknown field %s" % k) - - if name is None: - raise RuntimeError("Column's name field is missing.") - if data_type is None: - raise RuntimeError("Column's type field is missing.") - self.add_column(name, data_type, shape) + if isinstance(columns, list): + for column in columns: + try: + name = column.pop("name") + except KeyError: + raise RuntimeError("Column's name is missing") + try: + de_type = column.pop("type") + except KeyError: + raise RuntimeError("Column' type is missing") + shape = column.pop("shape", None) + column.pop("t_impl", None) + column.pop("rank", None) + if column: + raise RuntimeError("Unknown field {}".format(",".join(column.keys()))) + self.add_column(name, de_type, shape) + elif isinstance(columns, dict): + for key, value in columns.items(): + name = key + try: + de_type = value.pop("type") + except KeyError: + raise RuntimeError("Column' type is missing") + shape = value.pop("shape", None) + value.pop("t_impl", None) + value.pop("rank", None) + if value: + raise RuntimeError("Unknown field {}".format(",".join(value.keys()))) + self.add_column(name, de_type, shape) + else: + raise RuntimeError("columns must be dict or list, columns contain name, type, shape(optional).") def from_json(self, json_obj): """