From bc676fe250918b45b0c466a442a1ad744edbbb84 Mon Sep 17 00:00:00 2001 From: liyong Date: Mon, 13 Jul 2020 14:12:41 +0800 Subject: [PATCH] save op in minddataset --- .../ccsrc/minddata/dataset/api/de_pipeline.cc | 226 ++++++++++ .../ccsrc/minddata/dataset/api/de_pipeline.h | 16 + .../minddata/dataset/api/python_bindings.cc | 6 +- .../ccsrc/minddata/dataset/core/tensor.h | 10 +- .../engine/datasetops/source/mindrecord_op.cc | 2 +- .../mindrecord/include/common/shard_utils.h | 4 + .../mindrecord/include/shard_header.h | 4 + .../include/shard_index_generator.h | 2 + .../mindrecord/include/shard_writer.h | 7 + .../mindrecord/io/shard_index_generator.cc | 16 + .../minddata/mindrecord/io/shard_writer.cc | 52 +++ .../minddata/mindrecord/meta/shard_header.cc | 30 ++ mindspore/dataset/engine/datasets.py | 32 +- mindspore/dataset/engine/iterators.py | 23 ++ mindspore/dataset/engine/validators.py | 17 + tests/ut/python/dataset/test_save_op.py | 390 ++++++++++++++++++ 16 files changed, 828 insertions(+), 9 deletions(-) create mode 100644 tests/ut/python/dataset/test_save_op.py diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc index c780d8f64..0c4c6273a 100644 --- a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc @@ -42,11 +42,17 @@ #include "minddata/dataset/util/status.h" #include "minddata/mindrecord/include/shard_category.h" #include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_header.h" +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_shuffle.h" +#include "minddata/mindrecord/include/shard_writer.h" #include "pybind11/stl.h" #include "utils/log_adapter.h" namespace mindspore { namespace dataset { +using json = nlohmann::json; using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr *, std::shared_ptr *); static std::unordered_map g_parse_op_func_ = { @@ -355,6 +361,226 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr &file_names, const std::string &file_type) { + Status s; + auto mr_header = std::make_shared(); + auto mr_writer = std::make_unique(); + std::vector blob_fields; + uint64_t mr_schema_id = 0; + if (mindrecord::SUCCESS != mindrecord::ShardWriter::initialize(&mr_writer, file_names)) { + RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter."); + } + + TensorRow row; + std::unordered_map column_name_id_map = + iterator_->GetColumnNameMap(); // map of column name, id + bool first_loop = true; // build schema in first loop + do { + json row_raw_data; + std::map>> row_bin_data; + { + py::gil_scoped_release gil_release; + s = iterator_->FetchNextTensorRow(&row); + } + RETURN_IF_NOT_OK(s); + if (row.empty()) break; + if (first_loop) { + json mr_json; + std::vector index_fields; + s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields); + RETURN_IF_NOT_OK(s); + mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id); + mr_writer->SetShardHeader(mr_header); + first_loop = false; + } + // construct data + if (!row.empty()) { // write data + s = FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data); + RETURN_IF_NOT_OK(s); + std::shared_ptr> output_bin_data; + mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data); + std::map> raw_data; + 
raw_data.insert(std::pair>(mr_schema_id, std::vector{row_raw_data})); + std::vector> bin_data; + if (nullptr != output_bin_data) { + bin_data.emplace_back(*output_bin_data); + } + mr_writer->WriteRawData(raw_data, bin_data); + } + } while (!row.empty()); + mr_writer->Commit(); + mindrecord::ShardIndexGenerator::finalize(file_names); + return Status::OK(); +} + +Status DEPipeline::FetchDataFromTensorRow(const TensorRow &row, + const std::unordered_map &column_name_id_map, + json *row_raw_data, + std::map>> *row_bin_data) { + if (row_raw_data == nullptr) { + RETURN_STATUS_UNEXPECTED("error: row raw data is NULL."); + } + if (row_bin_data == nullptr) { + RETURN_STATUS_UNEXPECTED("error: row bin data is NULL."); + } + if (column_name_id_map.empty()) { + RETURN_STATUS_UNEXPECTED("Error: column not found"); + } + Status s; + for (auto &col : column_name_id_map) { + auto idx = col.second; + auto column_name = col.first; + auto &tensor = row[idx]; + auto column_type = tensor->type(); + + std::unique_ptr> data_ptr; + if (column_type == DataType::DE_INT8) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_INT16) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_UINT16) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_UINT8) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_INT32) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_UINT32) { + std::unique_ptr data; + std::unique_ptr dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_INT64) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_FLOAT32) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); + if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_FLOAT64) { + std::unique_ptr data, dummy; + s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); + RETURN_IF_NOT_OK(s); 
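+      // Note: `data` is only populated when the tensor is a scalar (empty shape); such scalar values are mirrored into the JSON raw record below.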
+ if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); + } else if (column_type == DataType::DE_STRING) { + auto buffer = tensor->GetStringsBuffer(); + std::string ss(reinterpret_cast(buffer)); // assume scalar string tensor + (*row_raw_data)[column_name] = std::move(ss); + continue; + } else { + RETURN_STATUS_UNEXPECTED("Got unexpected type when casting data."); + } + RETURN_IF_NOT_OK(s); + if (data_ptr != nullptr) { + (*row_bin_data)[column_name] = std::move(data_ptr); + } + } + return Status::OK(); +} + +template +Status DEPipeline::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, + std::unique_ptr *data, std::unique_ptr> *data_ptr, + std::unique_ptr *s, bool need_convert) { + if (nullptr == src) { + RETURN_STATUS_UNEXPECTED("Error: buffer of Tensor is NULL."); + } + *data_ptr = std::make_unique>(num_of_elements * sizeof(T)); + if (need_convert) { + auto tmp_ptr = std::make_unique>(num_of_elements * sizeof(S)); + std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin()); + auto s_ptr = reinterpret_cast(&(*(tmp_ptr->begin()))); + auto el = std::make_unique(); + for (uint32_t i = 0; i < num_of_elements; ++i) { + *el = *(s_ptr + i); + auto t_ptr = reinterpret_cast(el.get()); + for (uint32_t j = 0; j < sizeof(T); ++j) { + *((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j); + } + } + } else { + std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin()); + } + if (shape.empty()) { + *data = std::make_unique(); + auto t_ptr = reinterpret_cast((*data).get()); + for (uint32_t i = 0; i < sizeof(T); ++i) { + *(t_ptr + i) = *((*data_ptr)->begin() + i); + } + } + return Status::OK(); +} + +Status DEPipeline::FetchMetaFromTensorRow(const std::unordered_map &column_name_id_map, + const TensorRow &row, json *schema, std::vector *index_fields) { + if (schema == nullptr) { + RETURN_STATUS_UNEXPECTED("error: schema is NULL."); + } + if (index_fields == nullptr) { + RETURN_STATUS_UNEXPECTED("error: index fields is NULL."); + } + if (column_name_id_map.empty()) { + RETURN_STATUS_UNEXPECTED("Error: column not found."); + } + for (auto &col : column_name_id_map) { + auto idx = col.second; + auto column_name = col.first; + auto &tensor = row[idx]; + auto column_type = tensor->type(); + auto column_shape = tensor->shape(); + + std::string mr_type; + auto shapes = column_shape.AsVector(); + std::vector mr_shape(shapes.begin(), shapes.end()); + std::string el = column_type.ToString(); + if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) { + std::string err_msg("Error: can not support data type: " + el); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + mr_type = mindrecord::kTypesMap.at(el); + } + if (mr_shape.empty()) { + if (mr_type == "bytes") { // map to int32 when bytes without shape. + mr_type = "int32"; + } + (*schema)[column_name] = {{"type", mr_type}}; + } else { + if (mr_type == "string") { // mindrecord can not support string with shape.
+ std::string err_msg("Error: mindrecord can not support multi-dimensional string tensor."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (mr_type == "bytes") { // ignore shape of bytes in minrecord + (*schema)[column_name] = {{"type", mr_type}}; + } else { + (*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}}; + } + } + if (mr_type == "bytes" || !mr_shape.empty()) continue; + index_fields->emplace_back(column_name); // candidate of index fields + } + return Status::OK(); +} Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle, std::vector> *operators, int num_padded) { diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h index 755e827ef..b3adb6ae9 100644 --- a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h @@ -17,6 +17,7 @@ #define DATASET_API_DE_PIPELINE_H_ #include +#include #include #include #include @@ -33,6 +34,7 @@ namespace py = pybind11; namespace mindspore { namespace dataset { +using json = nlohmann::json; using DsOpPtr = std::shared_ptr; class CacheClient; @@ -100,6 +102,8 @@ class DEPipeline { Status GetOutputTypes(py::list *output); + Status SaveDataset(const std::vector &file_names, const std::string &file_type); + int GetDatasetSize() const; int GetBatchSize() const; @@ -110,6 +114,18 @@ class DEPipeline { Status ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + template + Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, + std::unique_ptr *data, std::unique_ptr> *data_ptr, + std::unique_ptr *s, bool need_convert = false); + + Status FetchMetaFromTensorRow(const std::unordered_map &column_name_id_map, + const TensorRow &row, json *schema, std::vector *index_fields); + + Status FetchDataFromTensorRow(const TensorRow &row, + const std::unordered_map &column_name_id_map, json *row_raw_data, + std::map>> *row_bin_data); + Status BuildMindrecordSamplerChain(const py::handle &handle, std::vector> *operators, int num_padded); diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc index 173c1af2f..b880c0cc4 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc @@ -184,7 +184,11 @@ void bindDEPipeline(py::module *m) { .def("GetDatasetSize", &DEPipeline::GetDatasetSize) .def("GetBatchSize", &DEPipeline::GetBatchSize) .def("GetNumClasses", &DEPipeline::GetNumClasses) - .def("GetRepeatCount", &DEPipeline::GetRepeatCount); + .def("GetRepeatCount", &DEPipeline::GetRepeatCount) + .def("SaveDataset", [](DEPipeline &de, const std::vector &file_names, const std::string &file_type) { + THROW_IF_ERROR(de.SaveDataset(file_names, file_type)); + return true; + }); } void bindDatasetOps(py::module *m) { (void)py::class_>(*m, "TFReaderOp") diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.h b/mindspore/ccsrc/minddata/dataset/core/tensor.h index b0b173e9c..8707cbd7c 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor.h +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.h @@ -312,6 +312,11 @@ class Tensor { // @return const unsigned char* const unsigned char *GetBuffer() const; + // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the + // tensor's type is a string, otherwise undefined address would be returned. 
+ // @return address of the first string of the tensor. + uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } + // Getter of the type // @return DataType type() const { return type_; } @@ -643,11 +648,6 @@ class Tensor { // @return length of the string Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const; - // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the - // tensor's type is a string, otherwise undefined address would be returned. - // @return address of the first string of the tensor. - uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } - // all access to shape_ should be via shape TensorShape shape_; // data type of tensor diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc index cf1493eb7..0886f7514 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -215,7 +215,7 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { // Call the super class for displaying any common detailed info ParallelOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\n Dataset file : "; + out << "\nDataset file : "; for (auto &file : dataset_file_) { out << file << " "; } diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h index bd1cda8a9..6c3e4e9c6 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h @@ -137,6 +137,10 @@ const std::set kScalarFieldTypeSet = {"string", "int32", "int64", " // number field list const std::set kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"}; +const std::unordered_map kTypesMap = { + {"bool", "int32"}, {"int8", "int32"}, {"uint8", "bytes"}, {"int16", "int32"}, + {"uint16", "int32"}, {"int32", "int32"}, {"uint32", "int64"}, {"int64", "int64"}, + {"float16", "float32"}, {"float32", "float32"}, {"float64", "float64"}, {"string", "string"}}; /// \brief split a string using a character /// \param[in] field target string /// \param[in] separator a character for spliting diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h index 67169e869..008f37941 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h @@ -124,6 +124,10 @@ class ShardHeader { MSRStatus FileToPages(const std::string dump_file_name); + static MSRStatus initialize(const std::shared_ptr *header_ptr, const json &schema, + const std::vector &index_fields, std::vector &blob_fields, + uint64_t &schema_id); + private: MSRStatus InitializeHeader(const std::vector &headers, bool load_dataset); diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h index fb85d9adb..c05b8876e 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h @@ -57,6 +57,8 @@ class ShardIndexGenerator { /// \brief create databases for 
indexes MSRStatus WriteToDatabase(); + static MSRStatus finalize(const std::vector file_names); + private: static int Callback(void *not_used, int argc, char **argv, char **az_col_name); diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h index 833928773..67d4e471f 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h @@ -108,6 +108,13 @@ class ShardWriter { std::map> &blob_data, bool sign = true, bool parallel_writer = false); + MSRStatus MergeBlobData(const std::vector &blob_fields, + const std::map>> &row_bin_data, + std::shared_ptr> *output); + + static MSRStatus initialize(const std::unique_ptr *writer_ptr, + const std::vector &file_names); + private: /// \brief write shard header data to disk MSRStatus WriteShardHeader(); diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc index f9b18a3bf..5b102c396 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc @@ -622,5 +622,21 @@ void ShardIndexGenerator::DatabaseWriter() { shard_no = task_++; } } +MSRStatus ShardIndexGenerator::finalize(const std::vector file_names) { + if (file_names.empty()) { + MS_LOG(ERROR) << "Mindrecord files is empty."; + return FAILED; + } + ShardIndexGenerator sg{file_names[0]}; + if (SUCCESS != sg.Build()) { + MS_LOG(ERROR) << "Failed to build index generator."; + return FAILED; + } + if (SUCCESS != sg.WriteToDatabase()) { + MS_LOG(ERROR) << "Failed to write to database."; + return FAILED; + } + return SUCCESS; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc index e85229cc3..2f2aebf7f 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc @@ -637,6 +637,42 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map *row_count = std::get<2>(v); return SUCCESS; } +MSRStatus ShardWriter::MergeBlobData(const std::vector &blob_fields, + const std::map>> &row_bin_data, + std::shared_ptr> *output) { + if (blob_fields.empty()) { + return SUCCESS; + } + if (blob_fields.size() == 1) { + auto &blob = row_bin_data.at(blob_fields[0]); + auto blob_size = blob->size(); + *output = std::make_shared>(blob_size); + std::copy(blob->begin(), blob->end(), (*output)->begin()); + } else { + size_t output_size = 0; + for (auto &field : blob_fields) { + output_size += row_bin_data.at(field)->size(); + } + output_size += blob_fields.size() * sizeof(uint64_t); + *output = std::make_shared>(output_size); + std::vector buf(sizeof(uint64_t), 0); + size_t idx = 0; + for (auto &field : blob_fields) { + auto &blob = row_bin_data.at(field); + uint64_t blob_size = blob->size(); + // big edian + for (size_t i = 0; i < buf.size(); ++i) { + buf[buf.size() - 1 - i] = std::numeric_limits::max() & blob_size; + blob_size >>= 8u; + } + std::copy(buf.begin(), buf.end(), (*output)->begin() + idx); + idx += buf.size(); + std::copy(blob->begin(), blob->end(), (*output)->begin() + idx); + idx += blob->size(); + } + } + return SUCCESS; +} MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, std::vector> &blob_data, bool sign, bool parallel_writer) { @@ -1250,5 +1286,21 @@ void ShardWriter::SetLastBlobPage(const int 
&shard_id, std::shared_ptr &la last_blob_page = page.first; } } + +MSRStatus ShardWriter::initialize(const std::unique_ptr *writer_ptr, + const std::vector &file_names) { + if (nullptr == writer_ptr) { + MS_LOG(ERROR) << "ShardWriter pointer is NULL."; + return FAILED; + } + auto res = (*writer_ptr)->Open(file_names, false); + if (SUCCESS != res) { + MS_LOG(ERROR) << "Failed to open mindrecord files to writer."; + return FAILED; + } + (*writer_ptr)->SetHeaderSize(1 << 24); + (*writer_ptr)->SetPageSize(1 << 25); + return SUCCESS; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc index 500037399..843b412a3 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc @@ -721,5 +721,35 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { page_in_handle.close(); return SUCCESS; } + +MSRStatus ShardHeader::initialize(const std::shared_ptr *header_ptr, const json &schema, + const std::vector &index_fields, std::vector &blob_fields, + uint64_t &schema_id) { + if (nullptr == header_ptr) { + MS_LOG(ERROR) << "ShardHeader pointer is NULL."; + return FAILED; + } + auto schema_ptr = Schema::Build("mindrecord", schema); + if (nullptr == schema_ptr) { + MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema."; + return FAILED; + } + schema_id = (*header_ptr)->AddSchema(schema_ptr); + // create index + std::vector> id_index_fields; + if (!index_fields.empty()) { + for (auto &el : index_fields) { + id_index_fields.emplace_back(schema_id, el); + } + if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) { + MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index."; + return FAILED; + } + } + + auto build_schema_ptr = (*header_ptr)->GetSchemas()[0]; + blob_fields = build_schema_ptr->GetBlobFields(); + return SUCCESS; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 846e7e0a5..f3136cefa 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -38,13 +38,13 @@ from mindspore._c_expression import typing from mindspore import log as logger from . 
import samplers -from .iterators import DictIterator, TupleIterator, DummyIterator +from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ check_rename, check_numpyslicesdataset, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ - check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32 + check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -1044,6 +1044,34 @@ class Dataset: return TransferDataset(self, queue_name, device_id, device_type, num_batch) + @check_save + def save(self, file_name, num_files=1, file_type='mindrecord'): + """ + Save the dynamic data processed by dataset pipeline as common dataset format, support: mindrecord. + + Note: + 1. To save the samples in order, should set dataset's shuffle false and num_files 1. + 2. Before call the function, do not use batch, repeat operator or data augmentation operators + with random attribute in map operator. + 3. Mindreocrd do not support np.uint64, multi-dimensional np.uint8(drop dimension) and + multi-dimensional string. + + Args: + file_name (str): Path to dataset file. + num_files (int, optional): Number of dataset files.(default=1). + file_type (str, optional): dataset format.(default='mindrecord') + + """ + + if num_files == 1: + file_names = [file_name] + else: + suffix = len(str(num_files - 1)) + file_names = ["{}{}".format(file_name, str(x).rjust(suffix, '0')) + for x in range(num_files)] + + return SaveOp(self).save(file_names, file_type) + def create_tuple_iterator(self, columns=None): """ Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data. diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index a2a23cbb4..45da97184 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -173,6 +173,7 @@ class Iterator: # Convert python node into C node and add to C layer execution tree in postorder traversal. def __convert_node_postorder(self, node): + self.check_node_type(node) op_type = self.__get_dataset_type(node) c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args()) @@ -224,6 +225,10 @@ class Iterator: self._index += 1 return data + @abstractmethod + def check_node_type(self, node): + pass + def get_output_shapes(self): return [t for t in self.depipeline.GetOutputShapes()] @@ -245,11 +250,27 @@ class Iterator: def __deepcopy__(self, memo): return self +class SaveOp(Iterator): + """ + The derived class of Iterator with dict type. + """ + def get_next(self): + pass + + def check_node_type(self, node): + if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)): + logger.warning("Used shuffle, repeat, batch before save operator.") + + def save(self, file_names, file_type): + return self.depipeline.SaveDataset(file_names, file_type) + class DictIterator(Iterator): """ The derived class of Iterator with dict type. 
""" + def check_node_type(self, node): + pass def __iter__(self): return self @@ -269,6 +290,8 @@ class TupleIterator(Iterator): """ The derived class of Iterator with list type. """ + def check_node_type(self, node): + pass def __init__(self, dataset, columns=None): if columns is not None: diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 29904f1a9..c61630a03 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -246,7 +246,24 @@ def check_celebadataset(method): return new_method +def check_save(method): + """A wrapper that wrap a parameter checker to the save op.""" + @wraps(method) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) + + nreq_param_int = ['num_files'] + nreq_param_str = ['file_name', 'file_type'] + validate_dataset_param_value(nreq_param_int, param_dict, int) + if(param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000): + raise ValueError("num_files should between {} and {}.".format(1, 1000)) + validate_dataset_param_value(nreq_param_str, param_dict, str) + if param_dict.get('file_type') != 'mindrecord': + raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type'))) + return method(self, *args, **kwargs) + + return new_method def check_minddataset(method): """A wrapper that wraps a parameter checker to the original Dataset(MindDataset).""" diff --git a/tests/ut/python/dataset/test_save_op.py b/tests/ut/python/dataset/test_save_op.py new file mode 100644 index 000000000..2ed326276 --- /dev/null +++ b/tests/ut/python/dataset/test_save_op.py @@ -0,0 +1,390 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +This is the test module for saveOp. 
+""" +import os +import mindspore.dataset as ds +from mindspore import log as logger +from mindspore.mindrecord import FileWriter +import numpy as np +import pytest + +CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord" +CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord" + +FILES_NUM = 1 +num_readers = 1 + + +@pytest.fixture(name="add_and_remove_cv_file") +def fixture_remove(): + """add/remove cv file""" + if os.path.exists("{}".format(CV_FILE_NAME1)): + os.remove("{}".format(CV_FILE_NAME1)) + if os.path.exists("{}.db".format(CV_FILE_NAME1)): + os.remove("{}.db".format(CV_FILE_NAME1)) + + if os.path.exists("{}".format(CV_FILE_NAME2)): + os.remove("{}".format(CV_FILE_NAME2)) + if os.path.exists("{}.db".format(CV_FILE_NAME2)): + os.remove("{}.db".format(CV_FILE_NAME2)) + yield "yield_cv_data" + if os.path.exists("{}".format(CV_FILE_NAME1)): + os.remove("{}".format(CV_FILE_NAME1)) + if os.path.exists("{}.db".format(CV_FILE_NAME1)): + os.remove("{}.db".format(CV_FILE_NAME1)) + + if os.path.exists("{}".format(CV_FILE_NAME2)): + os.remove("{}".format(CV_FILE_NAME2)) + if os.path.exists("{}.db".format(CV_FILE_NAME2)): + os.remove("{}.db".format(CV_FILE_NAME2)) + + +def test_case_00(add_and_remove_cv_file): # only bin data + data = [{"image1": bytes("image1 bytes abc", encoding='UTF-8'), + "image2": bytes("image1 bytes def", encoding='UTF-8'), + "image3": bytes("image1 bytes ghi", encoding='UTF-8'), + "image4": bytes("image1 bytes jkl", encoding='UTF-8'), + "image5": bytes("image1 bytes mno", encoding='UTF-8')}, + {"image1": bytes("image2 bytes abc", encoding='UTF-8'), + "image2": bytes("image2 bytes def", encoding='UTF-8'), + "image3": bytes("image2 bytes ghi", encoding='UTF-8'), + "image4": bytes("image2 bytes jkl", encoding='UTF-8'), + "image5": bytes("image2 bytes mno", encoding='UTF-8')}, + {"image1": bytes("image3 bytes abc", encoding='UTF-8'), + "image2": bytes("image3 bytes def", encoding='UTF-8'), + "image3": bytes("image3 bytes ghi", encoding='UTF-8'), + "image4": bytes("image3 bytes jkl", encoding='UTF-8'), + "image5": bytes("image3 bytes mno", encoding='UTF-8')}, + {"image1": bytes("image5 bytes abc", encoding='UTF-8'), + "image2": bytes("image5 bytes def", encoding='UTF-8'), + "image3": bytes("image5 bytes ghi", encoding='UTF-8'), + "image4": bytes("image5 bytes jkl", encoding='UTF-8'), + "image5": bytes("image5 bytes mno", encoding='UTF-8')}, + {"image1": bytes("image6 bytes abc", encoding='UTF-8'), + "image2": bytes("image6 bytes def", encoding='UTF-8'), + "image3": bytes("image6 bytes ghi", encoding='UTF-8'), + "image4": bytes("image6 bytes jkl", encoding='UTF-8'), + "image5": bytes("image6 bytes mno", encoding='UTF-8')}] + schema = { + "image1": {"type": "bytes"}, + "image2": {"type": "bytes"}, + "image3": {"type": "bytes"}, + "image4": {"type": "bytes"}, + "image5": {"type": "bytes"}} + writer = FileWriter(CV_FILE_NAME1, FILES_NUM) + writer.add_schema(schema, "schema") + writer.write_raw_data(data) + writer.commit() + + d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False) + d1.save(CV_FILE_NAME2, FILES_NUM) + data_value_to_list = [] + + for item in data: + new_data = {} + new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) + new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) + new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) + new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) + new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) + 
data_value_to_list.append(new_data) + + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + assert d2.get_dataset_size() == 5 + num_iter = 0 + for item in d2.create_dict_iterator(): + assert len(item) == 5 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 5 + + +def test_case_01(add_and_remove_cv_file): # only raw data + data = [{"file_name": "001.jpg", "label": 43}, + {"file_name": "002.jpg", "label": 91}, + {"file_name": "003.jpg", "label": 61}, + {"file_name": "004.jpg", "label": 29}, + {"file_name": "005.jpg", "label": 78}, + {"file_name": "006.jpg", "label": 37}] + schema = {"file_name": {"type": "string"}, + "label": {"type": "int32"} + } + + writer = FileWriter(CV_FILE_NAME1, FILES_NUM) + writer.add_schema(schema, "schema") + writer.write_raw_data(data) + writer.commit() + + d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False) + d1.save(CV_FILE_NAME2, FILES_NUM) + + data_value_to_list = [] + for item in data: + new_data = {} + new_data['file_name'] = np.asarray(item["file_name"], dtype='S') + new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) + data_value_to_list.append(new_data) + + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + assert d2.get_dataset_size() == 6 + num_iter = 0 + for item in d2.create_dict_iterator(): + logger.info(item) + assert len(item) == 2 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + +def test_case_02(add_and_remove_cv_file): # muti-bytes + data = [{"file_name": "001.jpg", "label": 43, + "float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12345, + "float64": 1987654321.123456785, + "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image1 bytes abc", encoding='UTF-8'), + "image2": bytes("image1 bytes def", encoding='UTF-8'), + "image3": bytes("image1 bytes ghi", encoding='UTF-8'), + "image4": bytes("image1 bytes jkl", encoding='UTF-8'), + "image5": bytes("image1 bytes mno", encoding='UTF-8')}, + {"file_name": "002.jpg", "label": 91, + "float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12445, + "float64": 1987654321.123456786, + "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image2 bytes abc", encoding='UTF-8'), + "image2": bytes("image2 bytes def", encoding='UTF-8'), + "image3": bytes("image2 bytes ghi", encoding='UTF-8'), + "image4": bytes("image2 bytes jkl", encoding='UTF-8'), + "image5": bytes("image2 bytes mno", encoding='UTF-8')}, + {"file_name": "003.jpg", "label": 61, + "float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32), 
+ "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12545, + "float64": 1987654321.123456787, + "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image3 bytes abc", encoding='UTF-8'), + "image2": bytes("image3 bytes def", encoding='UTF-8'), + "image3": bytes("image3 bytes ghi", encoding='UTF-8'), + "image4": bytes("image3 bytes jkl", encoding='UTF-8'), + "image5": bytes("image3 bytes mno", encoding='UTF-8')}, + {"file_name": "004.jpg", "label": 29, + "float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12645, + "float64": 1987654321.123456788, + "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image4 bytes abc", encoding='UTF-8'), + "image2": bytes("image4 bytes def", encoding='UTF-8'), + "image3": bytes("image4 bytes ghi", encoding='UTF-8'), + "image4": bytes("image4 bytes jkl", encoding='UTF-8'), + "image5": bytes("image4 bytes mno", encoding='UTF-8')}, + {"file_name": "005.jpg", "label": 78, + "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12745, + "float64": 1987654321.123456789, + "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image5 bytes abc", encoding='UTF-8'), + "image2": bytes("image5 bytes def", encoding='UTF-8'), + "image3": bytes("image5 bytes ghi", encoding='UTF-8'), + "image4": bytes("image5 bytes jkl", encoding='UTF-8'), + "image5": bytes("image5 bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "label": 37, + "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32), + "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471, + 123414314.2141243, 87.1212122], dtype=np.float64), + "float32": 3456.12745, + "float64": 1987654321.123456789, + "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int32), + "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image6 bytes abc", encoding='UTF-8'), + "image2": bytes("image6 bytes def", encoding='UTF-8'), + "image3": bytes("image6 bytes ghi", encoding='UTF-8'), + "image4": bytes("image6 bytes jkl", encoding='UTF-8'), + "image5": bytes("image6 bytes mno", encoding='UTF-8')} + ] + schema = {"file_name": {"type": "string"}, + "float32_array": {"type": "float32", "shape": [-1]}, + "float64_array": {"type": "float64", "shape": [-1]}, + "float32": {"type": "float32"}, + "float64": {"type": "float64"}, + "source_sos_ids": {"type": "int32", "shape": [-1]}, + "source_sos_mask": {"type": "int64", "shape": [-1]}, + "image1": {"type": "bytes"}, + "image2": {"type": "bytes"}, + "image3": {"type": "bytes"}, + "label": {"type": "int32"}, + "image4": {"type": "bytes"}, + "image5": {"type": "bytes"}} + writer = FileWriter(CV_FILE_NAME1, FILES_NUM) + writer.add_schema(schema, "schema") + writer.write_raw_data(data) + writer.commit() + + d1 = 
ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False) + d1.save(CV_FILE_NAME2, FILES_NUM) + data_value_to_list = [] + + for item in data: + new_data = {} + new_data['file_name'] = np.asarray(item["file_name"], dtype='S') + new_data['float32_array'] = item["float32_array"] + new_data['float64_array'] = item["float64_array"] + new_data['float32'] = item["float32"] + new_data['float64'] = item["float64"] + new_data['source_sos_ids'] = item["source_sos_ids"] + new_data['source_sos_mask'] = item["source_sos_mask"] + new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) + new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) + new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) + new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) + new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) + new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) + data_value_to_list.append(new_data) + + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + assert d2.get_dataset_size() == 6 + num_iter = 0 + for item in d2.create_dict_iterator(): + assert len(item) == 13 + for field in item: + if isinstance(item[field], np.ndarray): + if item[field].dtype == np.float32: + assert (item[field] == + np.array(data_value_to_list[num_iter][field], np.float32)).all() + else: + assert (item[field] == + data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + +def generator_1d(): + for i in range(10): + yield (np.array([i]),) + + +def test_case_03(add_and_remove_cv_file): + + # apply dataset operations + d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False) + + d1.save(CV_FILE_NAME2) + + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + + i = 0 + for item in d2.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(item["data"], golden) + i = i + 1 + + +def generator_with_type(t): + for i in range(64): + yield (np.array([i], dtype=t),) + + +def type_tester(t): + logger.info("Test with Type {}".format(t.__name__)) + + # apply dataset operations + data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], shuffle=False) + + data1 = data1.batch(4) + + data1 = data1.repeat(3) + + data1.save(CV_FILE_NAME2) + + d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, + num_parallel_workers=num_readers, + shuffle=False) + + i = 0 + num_repeat = 0 + for item in d2.create_dict_iterator(): # each data is a dictionary + golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t) + logger.info(item) + assert np.array_equal(item["data"], golden) + i = i + 4 + if i == 64: + i = 0 + num_repeat += 1 + assert num_repeat == 3 + if os.path.exists("{}".format(CV_FILE_NAME2)): + os.remove("{}".format(CV_FILE_NAME2)) + if os.path.exists("{}.db".format(CV_FILE_NAME2)): + os.remove("{}.db".format(CV_FILE_NAME2)) + + +def test_case_04(): + # uint8 will drop shape as mindrecord store uint8 as bytes + types = [np.int8, np.int16, np.int32, np.int64, + np.uint16, np.uint32, np.float32, np.float64] + + for t in types: + type_tester(t) + + +def test_case_05(add_and_remove_cv_file): + + d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False) + + with pytest.raises(Exception, match="num_files should between 1 and 1000."): + d1.save(CV_FILE_NAME2, 0) + + +def 
test_case_06(add_and_remove_cv_file): + + d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False) + + with pytest.raises(Exception, match="tfrecord dataset format is not supported."): + d1.save(CV_FILE_NAME2, 1, "tfrecord") -- GitLab
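
Usage sketch (reviewer note, not part of the patch): a minimal end-to-end example of the Dataset.save() API added here, assuming a writable output path "./saved.mindrecord" and an illustrative generator named gen; argument names and constraints follow the save() docstring in datasets.py and the tests above.

    import numpy as np
    import mindspore.dataset as ds

    def gen():
        # ten scalar int32 samples
        for i in range(10):
            yield (np.array([i], dtype=np.int32),)

    # Keep shuffle off, num_files=1, and avoid batch/repeat/random maps before
    # saving, as the save() docstring advises, so samples keep their order.
    data = ds.GeneratorDataset(gen, ["data"], shuffle=False)
    data.save("./saved.mindrecord", num_files=1, file_type="mindrecord")

    # save() writes the mindrecord file plus a companion .db index; both can be
    # read back with MindDataset (remove stale copies first, as the test fixture does).
    reloaded = ds.MindDataset(dataset_file="./saved.mindrecord", shuffle=False)
    for row in reloaded.create_dict_iterator():
        print(row["data"])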