diff --git a/mindspore/mindrecord/filewriter.py b/mindspore/mindrecord/filewriter.py index 4056825ff3434824d7987ec44ad7965950edb281..90bca480382507b1e414e4414e74fbbb2b2274c8 100644 --- a/mindspore/mindrecord/filewriter.py +++ b/mindspore/mindrecord/filewriter.py @@ -26,8 +26,7 @@ from .shardheader import ShardHeader from .shardindexgenerator import ShardIndexGenerator from .shardutils import MIN_SHARD_COUNT, MAX_SHARD_COUNT, VALID_ATTRIBUTES, VALID_ARRAY_ATTRIBUTES, \ check_filename, VALUE_TYPE_MAP -from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError, \ - MRMValidateDataError +from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError __all__ = ['FileWriter'] @@ -201,52 +200,13 @@ class FileWriter: raw_data.pop(i) logger.warning(v) - def _verify_based_on_blob_fields(self, raw_data): + def write_raw_data(self, raw_data): """ - Verify data according to blob fields which is sub set of schema's fields. - - Raise exception if validation failed. - 1) allowed data type contains: "int32", "int64", "float32", "float64", "string", "bytes". - - Args: - raw_data (list[dict]): List of raw data. - - Raises: - MRMValidateDataError: If data does not match blob fields. - """ - schema_content = self._header.schema - for field in schema_content: - for i, v in enumerate(raw_data): - if field not in v: - raise MRMValidateDataError("for schema, {} th data is wrong: "\ - "there is not '{}' object in the raw data.".format(i, field)) - if field in self._header.blob_fields: - field_type = type(v[field]).__name__ - if field_type not in VALUE_TYPE_MAP: - raise MRMValidateDataError("for schema, {} th data is wrong: "\ - "data type for '{}' is not matched.".format(i, field)) - if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]: - raise MRMValidateDataError("for schema, {} th data is wrong: "\ - "data type for '{}' is not matched.".format(i, field)) - if field_type == 'ndarray': - if 'shape' not in schema_content[field]: - raise MRMValidateDataError("for schema, {} th data is wrong: " \ - "data type for '{}' is not matched.".format(i, field)) - try: - # tuple or list - np.reshape(v[field], schema_content[field]['shape']) - except ValueError: - raise MRMValidateDataError("for schema, {} th data is wrong: " \ - "data type for '{}' is not matched.".format(i, field)) - - def write_raw_data(self, raw_data, validate=True): - """ - Write raw data and generate sequential pair of MindRecord File. + Write raw data and generate sequential pair of MindRecord File and \ + validate data based on predefined schema by default. Args: raw_data (list[dict]): List of raw data. - validate (bool, optional): Validate data according schema if it equals to True, - or validate data according to blob fields (default=True). Raises: ParamTypeError: If index field is invalid. @@ -264,11 +224,8 @@ class FileWriter: for each_raw in raw_data: if not isinstance(each_raw, dict): raise ParamTypeError('raw_data item', 'dict') - if validate is True: - self._verify_based_on_schema(raw_data) - elif validate is False: - self._verify_based_on_blob_fields(raw_data) - return self._writer.write_raw_data(raw_data, validate) + self._verify_based_on_schema(raw_data) + return self._writer.write_raw_data(raw_data, True) def set_header_size(self, header_size): """