From aa3f89e74f16f6de1efc92b4ac7ae66f83b05652 Mon Sep 17 00:00:00 2001 From: liyong Date: Fri, 8 May 2020 16:11:59 +0800 Subject: [PATCH] mindrecord support read file list --- mindspore/ccsrc/dataset/api/de_pipeline.cc | 9 +- .../ccsrc/dataset/api/python_bindings.cc | 21 +-- .../engine/datasetops/source/mindrecord_op.cc | 33 +++-- .../engine/datasetops/source/mindrecord_op.h | 30 +++-- .../ccsrc/mindrecord/common/shard_error.cc | 3 + .../ccsrc/mindrecord/common/shard_pybind.cc | 6 +- .../ccsrc/mindrecord/include/shard_error.h | 3 +- .../ccsrc/mindrecord/include/shard_header.h | 13 +- .../ccsrc/mindrecord/include/shard_reader.h | 21 +-- .../mindrecord/io/shard_index_generator.cc | 18 ++- mindspore/ccsrc/mindrecord/io/shard_reader.cc | 93 ++++++++----- mindspore/ccsrc/mindrecord/io/shard_writer.cc | 23 +++- .../ccsrc/mindrecord/meta/shard_header.cc | 56 ++++---- mindspore/dataset/engine/datasets.py | 14 +- mindspore/dataset/engine/validators.py | 7 +- mindspore/mindrecord/filereader.py | 9 +- mindspore/mindrecord/mindpage.py | 9 +- mindspore/mindrecord/shardreader.py | 9 +- mindspore/mindrecord/shardsegment.py | 9 +- tests/ut/cpp/dataset/mind_record_op_test.cc | 21 ++- .../cpp/mindrecord/ut_shard_operator_test.cc | 34 ++--- .../ut/cpp/mindrecord/ut_shard_reader_test.cc | 18 +-- .../cpp/mindrecord/ut_shard_segment_test.cc | 10 +- .../ut/cpp/mindrecord/ut_shard_writer_test.cc | 10 +- tests/ut/python/dataset/test_minddataset.py | 123 +++++++++++++++++- .../dataset/test_minddataset_exception.py | 57 ++++++++ .../python/mindrecord/test_mindrecord_base.py | 10 ++ 27 files changed, 496 insertions(+), 173 deletions(-) diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index 4a5dac198..27e5f8d2b 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -408,8 +408,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr builder = std::make_shared(); - (void)builder->SetDatasetFile(ToString(args["dataset_file"])); - + bool load_dataset = ToBool(args["load_dataset"]); + if (load_dataset == true) { + (void)builder->SetDatasetFile({ToString(args["dataset_file"])}); + } else { + (void)builder->SetDatasetFile(ToStringVector(args["dataset_file"])); + } + (void)builder->SetLoadDataset(load_dataset); std::vector in_col_names; if (!args["columns_list"].is_none()) { in_col_names = ToStringVector(args["columns_list"]); diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 8e942e1d9..efb04a3bd 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -151,16 +151,17 @@ void bindDatasetOps(py::module *m) { }); (void)py::class_>(*m, "MindRecordOp") - .def_static("get_num_rows", [](const std::string &path, const py::object &sampler) { - int64_t count = 0; - std::shared_ptr op; - if (py::hasattr(sampler, "_create_for_minddataset")) { - auto create = sampler.attr("_create_for_minddataset"); - op = create().cast>(); - } - THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count)); - return count; - }); + .def_static("get_num_rows", + [](const std::vector &paths, bool load_dataset, const py::object &sampler) { + int64_t count = 0; + std::shared_ptr op; + if (py::hasattr(sampler, "_create_for_minddataset")) { + auto create = sampler.attr("_create_for_minddataset"); + op = create().cast>(); + } + THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count)); + 
return count; + }); (void)py::class_>(*m, "ManifestOp") .def_static("get_num_rows_and_classes", diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index bd6f03828..1644ce1cf 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -40,7 +40,7 @@ using mindrecord::ShardOperator; using mindrecord::ShardReader; // Builder constructor. Creates the builder object. -MindRecordOp::Builder::Builder() : build_dataset_file_("") { +MindRecordOp::Builder::Builder() : build_dataset_file_({}) { // Some arguments to the MindRecordOp constructor have a default argument that is taken // from the client config. // The user may choose to change these values for the construction of the StorageOp by @@ -63,9 +63,9 @@ Status MindRecordOp::Builder::Build(std::shared_ptr *ptr) { "Building a MindRecordOp that has not provided a file."); } - new_mind_record_op = std::make_shared(build_num_mind_record_workers_, build_rows_per_buffer_, - build_dataset_file_, build_op_connector_queue_size_, - build_columns_to_load_, build_operators_, build_block_reader_); + new_mind_record_op = std::make_shared( + build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_, + build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_); RETURN_IF_NOT_OK(new_mind_record_op->Init()); @@ -76,12 +76,14 @@ Status MindRecordOp::Builder::Build(std::shared_ptr *ptr) { Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); } // Constructor of the MindRecordOp. -MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file, - int32_t op_connector_queue_size, const std::vector &columns_to_load, +MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, + std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, + const std::vector &columns_to_load, const std::vector> &operators, const bool &block_reader) : ParallelOp(num_mind_record_workers, op_connector_queue_size), rows_per_buffer_(rows_per_buffer), dataset_file_(dataset_file), + load_dataset_(load_dataset), columns_to_load_(columns_to_load), operators_(operators), num_mind_record_workers_(num_mind_record_workers), @@ -101,9 +103,10 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf // Private helper method to encapsulate some common construction/reset tasks Status MindRecordOp::Init() { shard_reader_ = std::make_unique(); - auto rc = shard_reader_->Open(dataset_file_, num_mind_record_workers_, columns_to_load_, operators_, block_reader_); + auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, + block_reader_); - CHECK_FAIL_RETURN_UNEXPECTED(rc != MSRStatus::FAILED, + CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, "MindRecordOp init failed. 
Error message: " + ErrnoToMessage(rc)); data_schema_ = std::make_unique(); @@ -201,8 +204,12 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { // Call the super class for displaying any common detailed info ParallelOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\n1 Dataset file : " << dataset_file_ << "\nNumber of rows : " << num_rows_ - << "\nRows per buffer : " << rows_per_buffer_ << "\nNumber of buffers : " << buffers_needed_ + out << "\n Dataset file : "; + for (auto &file : dataset_file_) { + out << file << " "; + } + out << "\nNumber of rows : " << num_rows_ << "\nRows per buffer : " << rows_per_buffer_ + << "\nNumber of buffers : " << buffers_needed_ << "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n"; } } @@ -668,10 +675,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() { return Status::OK(); } -Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr &op, - int64_t *count) { +Status MindRecordOp::CountTotalRows(const std::vector dataset_path, bool load_dataset, + const std::shared_ptr &op, int64_t *count) { std::unique_ptr shard_reader = std::make_unique(); - MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count); + MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count); if (rc == MSRStatus::FAILED) { RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h index fbf3d8f98..a3a066bfc 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h @@ -77,8 +77,8 @@ class MindRecordOp : public ParallelOp { return *this; } - Builder &SetDatasetFile(const std::string &file) { - build_dataset_file_ = file; + Builder &SetDatasetFile(const std::vector &files) { + build_dataset_file_ = files; return *this; } @@ -97,6 +97,11 @@ class MindRecordOp : public ParallelOp { return *this; } + Builder &SetLoadDataset(bool load_dataset) { + build_load_dataset_ = load_dataset; + return *this; + } + Status SanityCheck() const; static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; } @@ -109,7 +114,8 @@ class MindRecordOp : public ParallelOp { int32_t builder_num_workers_; int32_t build_rows_per_buffer_; int32_t build_op_connector_queue_size_; - std::string build_dataset_file_; + std::vector build_dataset_file_; + bool build_load_dataset_; std::vector build_columns_to_load_; std::vector> build_operators_; bool build_block_reader_; @@ -119,12 +125,12 @@ class MindRecordOp : public ParallelOp { // @note The builder class should be used to call it // @param num_mind_record_workers - The number of workers for the op (run by ShardReader) // @param rows_per_buffer - The requested number of rows per buffer - // @param dataset_file - A shard file + // @param dataset_file - dataset files // @param op_connector_queue_size - The output connector queue size // @param columns_to_load - The list of columns to use (column name) // @param operators - ShardOperators for Shuffle, Category, Sample - MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file, - int32_t op_connector_queue_size, const std::vector &columns_to_load, + MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector dataset_file, + bool load_dataset, int32_t 
op_connector_queue_size, const std::vector &columns_to_load, const std::vector> &operators, const bool &block_reader); // Destructor @@ -169,21 +175,22 @@ class MindRecordOp : public ParallelOp { // Getter method int32_t num_rows() const { return num_rows_; } - // Getter method - static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr &op, - int64_t *count); + static Status CountTotalRows(const std::vector dataset_path, bool load_dataset, + const std::shared_ptr &op, int64_t *count); // Getter method int32_t rows_per_buffer() const { return rows_per_buffer_; } // Getter method - std::string dataset_file() const { return dataset_file_; } + std::vector dataset_file() const { return dataset_file_; } // Getter method std::vector columns_to_load() const { return columns_to_load_; } bool block_reader() const { return block_reader_; } + bool load_dataset() const { return load_dataset_; } + Status Init(); Status SetColumnsBlob(); @@ -246,7 +253,8 @@ class MindRecordOp : public ParallelOp { Status FetchBlockBuffer(const int32_t &buffer_id); int32_t rows_per_buffer_; // The number of requested rows per buffer. - std::string dataset_file_; // A dataset file + std::vector dataset_file_; // dataset files + bool load_dataset_; // load dataset from single file or not std::vector columns_to_load_; // Columns to load from dataset std::vector> operators_; // ShardOperators to use int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader diff --git a/mindspore/ccsrc/mindrecord/common/shard_error.cc b/mindspore/ccsrc/mindrecord/common/shard_error.cc index cf43dcb31..ad68aaf92 100644 --- a/mindspore/ccsrc/mindrecord/common/shard_error.cc +++ b/mindspore/ccsrc/mindrecord/common/shard_error.cc @@ -170,6 +170,9 @@ std::string ErrnoToMessage(MSRStatus status) { case IO_FAILED: return "io operate failed"; break; + case MATCH_HEADER_FAILED: + return "match header failed"; + break; default: return "invalid error no"; } diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc index 0eb9ac14b..0391ee5e1 100644 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc @@ -84,7 +84,8 @@ void BindShardWriter(py::module *m) { void BindShardReader(const py::module *m) { (void)py::class_>(*m, "ShardReader", py::module_local()) .def(py::init<>()) - .def("open", (MSRStatus(ShardReader::*)(const std::string &, const int &, const std::vector &, + .def("open", (MSRStatus(ShardReader::*)(const std::vector &, bool, const int &, + const std::vector &, const std::vector> &)) & ShardReader::OpenPy) .def("launch", &ShardReader::Launch) @@ -106,7 +107,8 @@ void BindShardIndexGenerator(const py::module *m) { void BindShardSegment(py::module *m) { (void)py::class_(*m, "ShardSegment", py::module_local()) .def(py::init<>()) - .def("open", (MSRStatus(ShardSegment::*)(const std::string &, const int &, const std::vector &, + .def("open", (MSRStatus(ShardSegment::*)(const std::vector &, bool, const int &, + const std::vector &, const std::vector> &)) & ShardSegment::OpenPy) .def("get_category_fields", diff --git a/mindspore/ccsrc/mindrecord/include/shard_error.h b/mindspore/ccsrc/mindrecord/include/shard_error.h index b85eeb71c..8488ca70c 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_error.h +++ b/mindspore/ccsrc/mindrecord/include/shard_error.h @@ -72,7 +72,8 @@ enum MSRStatus { ILLEGAL_PARAMETERS, GET_PAGE_BY_GROUP_ID_FAILED, GET_SYSTEM_STATE_FAILED, - IO_FAILED + 
IO_FAILED, + MATCH_HEADER_FAILED }; // convert error no to string message diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h index d2c2ef0a2..0f2473e91 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/mindrecord/include/shard_header.h @@ -35,10 +35,11 @@ class ShardHeader { public: ShardHeader(); - MSRStatus Build(const std::string &file_path); - ~ShardHeader() = default; + MSRStatus BuildDataset(const std::vector &file_paths, bool load_dataset = true); + + static std::pair BuildSingleHeader(const std::string &file_path); /// \brief add the schema and save it /// \param[in] schema the schema needs to be added /// \return the last schema's id @@ -126,7 +127,7 @@ class ShardHeader { MSRStatus FileToPages(const std::string dump_file_name); private: - MSRStatus InitializeHeader(const std::vector &headers); + MSRStatus InitializeHeader(const std::vector &headers, bool load_dataset); /// \brief get the headers from all the shard data /// \param[in] the shard data real path @@ -137,9 +138,9 @@ class ShardHeader { MSRStatus ValidateField(const std::vector &field_name, json schema, const uint64_t &schema_id); /// \brief check the binary file status - MSRStatus CheckFileStatus(const std::string &path); + static MSRStatus CheckFileStatus(const std::string &path); - std::pair ValidateHeader(const std::string &path); + static std::pair ValidateHeader(const std::string &path); void ParseHeader(const json &header); @@ -149,7 +150,7 @@ class ShardHeader { MSRStatus CheckIndexField(const std::string &field, const json &schema); - void ParsePage(const json &page); + void ParsePage(const json &page, int shard_index, bool load_dataset); MSRStatus ParseStatistics(const json &statistics); diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h index 840f5a1b4..8ce9e3fdf 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h @@ -68,23 +68,25 @@ class ShardReader { virtual ~ShardReader(); /// \brief open files and initialize reader, c++ API - /// \param[in] file_path the path of ONE file, any file in dataset is fine + /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list + /// \param[in] load_dataset load dataset from single file or not /// \param[in] n_consumer number of threads when reading /// \param[in] selected_columns column list to be populated /// \param[in] operators operators applied to data, operator type is shuffle, sample or category /// \param[in] block_reader block-reader mode if true, otherwise row-reader mode /// \return MSRStatus the status of MSRStatus - MSRStatus Open(const std::string &file_path, int n_consumer = 4, + MSRStatus Open(const std::vector &file_paths, bool load_dataset, int n_consumer = 4, const std::vector &selected_columns = {}, const std::vector> &operators = {}, const bool &block_reader = false); /// \brief open files and initialize reader, python API - /// \param[in] file_path the path of ONE file, any file in dataset is fine + /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list + /// \param[in] load_dataset load dataset from single file or not /// \param[in] n_consumer number of threads when reading /// \param[in] selected_columns column list to be populated /// \param[in] operators operators applied to data, operator type is shuffle, sample or category /// \return MSRStatus the status 
of MSRStatus - MSRStatus OpenPy(const std::string &file_path, const int &n_consumer = 4, + MSRStatus OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer = 4, const std::vector &selected_columns = {}, const std::vector> &operators = {}); @@ -114,11 +116,13 @@ class ShardReader { int GetShardCount() const; /// \brief get the number of rows in database - /// \param[in] file_path the path of ONE file, any file in dataset is fine + /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list + /// \param[in] load_dataset load dataset from single file or not /// \param[in] op smart pointer refer to ShardCategory or ShardSample object /// \param[out] count # of rows /// \return MSRStatus the status of MSRStatus - MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr &op, int64_t *count); + MSRStatus CountTotalRows(const std::vector &file_paths, bool load_dataset, + const std::shared_ptr &op, int64_t *count); /// \brief shuffle task with incremental seed /// \return void @@ -220,7 +224,7 @@ class ShardReader { std::vector> &column_values); /// \brief initialize reader - MSRStatus Init(const std::string &file_path); + MSRStatus Init(const std::vector &file_paths, bool load_dataset); /// \brief validate column list MSRStatus CheckColumnList(const std::vector &selected_columns); @@ -292,8 +296,9 @@ class ShardReader { void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set &categories); /// \brief get number of classes - int64_t GetNumClasses(const std::string &file_path, const std::string &category_field); + int64_t GetNumClasses(const std::string &category_field); + std::pair> GetMeta(const std::string &file_path, json &meta_data); /// \brief get exactly blob fields data by indices std::vector ExtractBlobFieldBySelectColumns(std::vector &blob_fields_bytes, std::vector &ordered_selected_columns_index); diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc index d831a9331..905968e3a 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc @@ -36,9 +36,23 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe write_success_(true) {} MSRStatus ShardIndexGenerator::Build() { + auto ret = ShardHeader::BuildSingleHeader(file_path_); + if (ret.first != SUCCESS) { + return FAILED; + } + auto json_header = ret.second; + + auto ret2 = GetParentDir(file_path_); + if (SUCCESS != ret2.first) { + return FAILED; + } + std::vector real_addresses; + for (const auto &path : json_header["shard_addresses"]) { + std::string abs_path = ret2.second + string(path); + real_addresses.emplace_back(abs_path); + } ShardHeader header = ShardHeader(); - if (header.Build(file_path_) != SUCCESS) { - MS_LOG(ERROR) << "Build shard schema failed."; + if (header.BuildDataset(real_addresses) == FAILED) { return FAILED; } shard_header_ = header; diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index bd0394ac4..ed5af0e6a 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -47,20 +47,55 @@ ShardReader::ShardReader() { block_reader_ = false; } -MSRStatus ShardReader::Init(const std::string &file_path) { +std::pair> ShardReader::GetMeta(const std::string &file_path, json &meta_data) { if (!IsLegalFile(file_path)) { + return {FAILED, {}}; + } + auto ret 
= ShardHeader::BuildSingleHeader(file_path); + if (ret.first != SUCCESS) { + return {FAILED, {}}; + } + auto header = ret.second; + meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, + {"version", header["version"]}, {"index_fields", header["index_fields"]}, + {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; + return {SUCCESS, header["shard_addresses"]}; +} + +MSRStatus ShardReader::Init(const std::vector &file_paths, bool load_dataset) { + std::string file_path = file_paths[0]; + json first_meta_data = json(); + auto ret = GetMeta(file_path, first_meta_data); + if (ret.first != SUCCESS) { return FAILED; } - ShardHeader sh = ShardHeader(); - if (sh.Build(file_path) == FAILED) { + if (file_paths.size() == 1 && load_dataset == true) { + auto ret2 = GetParentDir(file_path); + if (SUCCESS != ret2.first) { + return FAILED; + } + std::vector real_addresses; + for (const auto &path : ret.second) { + std::string abs_path = ret2.second + string(path); + real_addresses.emplace_back(abs_path); + } + file_paths_ = real_addresses; + } else if (file_paths.size() >= 1 && load_dataset == false) { + file_paths_ = file_paths; + } else { + MS_LOG(ERROR) << "Error in parameter file_path or load_dataset."; return FAILED; } - shard_header_ = std::make_shared(sh); - header_size_ = shard_header_->GetHeaderSize(); - page_size_ = shard_header_->GetPageSize(); - file_paths_ = shard_header_->GetShardAddresses(); - for (const auto &file : file_paths_) { + json meta_data = json(); + auto ret1 = GetMeta(file, meta_data); + if (ret1.first != SUCCESS) { + return FAILED; + } + if (meta_data != first_meta_data) { + MS_LOG(ERROR) << "Mindrecord files meta information is different."; + return FAILED; + } sqlite3 *db = nullptr; // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); @@ -91,7 +126,13 @@ MSRStatus ShardReader::Init(const std::string &file_path) { } database_paths_.push_back(db); } - + ShardHeader sh = ShardHeader(); + if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) { + return FAILED; + } + shard_header_ = std::make_shared(sh); + header_size_ = shard_header_->GetHeaderSize(); + page_size_ = shard_header_->GetPageSize(); num_rows_ = 0; auto row_group_summary = ReadRowGroupSummary(); for (const auto &rg : row_group_summary) { @@ -248,7 +289,6 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vectorclose(); return FAILED; } - json label_json = json::from_msgpack(label_raw); json tmp; if (!columns.empty()) { @@ -713,15 +753,9 @@ MSRStatus ShardReader::Finish() { return SUCCESS; } -int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) { - ShardHeader sh = ShardHeader(); - if (sh.Build(file_path) == FAILED) { - return -1; - } - auto header = std::make_shared(sh); - auto file_paths = header->GetShardAddresses(); - auto shard_count = file_paths.size(); - auto index_fields = header->GetFields(); +int64_t ShardReader::GetNumClasses(const std::string &category_field) { + auto shard_count = file_paths_.size(); + auto index_fields = shard_header_->GetFields(); std::map map_schema_id_fields; for (auto &field : index_fields) { @@ -742,7 +776,7 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri std::set categories; for (int x = 0; x < shard_count; x++) { sqlite3 *db = nullptr; - int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, 
SQLITE_OPEN_READONLY, nullptr); + int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); if (SQLITE_OK != rc) { MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); return -1; @@ -756,16 +790,16 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri return categories.size(); } -MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr &op, - int64_t *count) { - if (Init(file_path) == FAILED) { +MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths, bool load_dataset, + const std::shared_ptr &op, int64_t *count) { + if (SUCCESS != Init(file_paths, load_dataset)) { return FAILED; } int64_t num_samples = num_rows_; if (std::dynamic_pointer_cast(op)) { auto category_op = std::dynamic_pointer_cast(op); std::string category_field = category_op->GetCategoryField(); - auto num_classes = GetNumClasses(file_path, category_field); + auto num_classes = GetNumClasses(category_field); num_samples = category_op->GetNumSamples(num_rows_, num_classes); } else if (std::dynamic_pointer_cast(op)) { num_samples = op->GetNumSamples(num_rows_, 0); @@ -779,12 +813,13 @@ MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::s return SUCCESS; } -MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer, +MSRStatus ShardReader::Open(const std::vector &file_paths, bool load_dataset, int n_consumer, const std::vector &selected_columns, const std::vector> &operators, const bool &block_reader) { // Open file and set header by ShardReader - if (Init(file_path) == FAILED) { - return FAILED; + auto ret = Init(file_paths, load_dataset); + if (SUCCESS != ret) { + return ret; } auto thread_limit = GetMaxThreadNum(); if (n_consumer > thread_limit) { @@ -837,11 +872,11 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer, return SUCCESS; } -MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consumer, +MSRStatus ShardReader::OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer, const std::vector &selected_columns, const std::vector> &operators) { // Open file and set header by ShardReader - if (Init(file_path) == FAILED) { + if (SUCCESS != Init(file_paths, load_dataset)) { return FAILED; } // should remove blob field from selected_columns when call from python diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc index 4a33bfddb..43967c43c 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc @@ -174,12 +174,25 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { if (!IsLegalFile(path)) { return FAILED; } - ShardHeader sh = ShardHeader(); - if (sh.Build(path) == FAILED) { + auto ret1 = ShardHeader::BuildSingleHeader(path); + if (ret1.first != SUCCESS) { return FAILED; } - shard_header_ = std::make_shared(sh); - auto paths = shard_header_->GetShardAddresses(); + auto json_header = ret1.second; + auto ret2 = GetParentDir(path); + if (SUCCESS != ret2.first) { + return FAILED; + } + std::vector real_addresses; + for (const auto &path : json_header["shard_addresses"]) { + std::string abs_path = ret2.second + string(path); + real_addresses.emplace_back(abs_path); + } + ShardHeader header = ShardHeader(); + if (header.BuildDataset(real_addresses) == FAILED) { + return FAILED; + } + shard_header_ = std::make_shared(header); MSRStatus ret = 
SetHeaderSize(shard_header_->GetHeaderSize()); if (ret == FAILED) { return FAILED; @@ -188,7 +201,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { if (ret == FAILED) { return FAILED; } - ret = Open(paths, true); + ret = Open(json_header["shard_addresses"], true); if (ret == FAILED) { MS_LOG(ERROR) << "Open file failed"; return FAILED; diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc index 8db2c6b7c..3adb01735 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_header.cc @@ -35,8 +35,9 @@ namespace mindrecord { std::atomic thread_status(false); ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared(); } -MSRStatus ShardHeader::InitializeHeader(const std::vector &headers) { +MSRStatus ShardHeader::InitializeHeader(const std::vector &headers, bool load_dataset) { shard_count_ = headers.size(); + int shard_index = 0; bool first = true; for (const auto &header : headers) { if (first) { @@ -54,7 +55,8 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector &headers) { header_size_ = header["header_size"].get(); page_size_ = header["page_size"].get(); } - ParsePage(header["page"]); + ParsePage(header["page"], shard_index, load_dataset); + shard_index++; } return SUCCESS; } @@ -136,40 +138,39 @@ std::pair ShardHeader::ValidateHeader(const std::string &path) return {SUCCESS, json_header}; } -MSRStatus ShardHeader::Build(const std::string &file_path) { +std::pair ShardHeader::BuildSingleHeader(const std::string &file_path) { auto ret = ValidateHeader(file_path); if (SUCCESS != ret.first) { - return FAILED; - } - json main_header = ret.second; - json addresses = main_header["shard_addresses"]; - vector real_addresses; - auto ret1 = GetParentDir(file_path); - if (SUCCESS != ret1.first) { - return FAILED; + return {FAILED, json()}; } - std::string parent_dir = ret1.second; + json raw_header = ret.second; + json header = {{"shard_addresses", raw_header["shard_addresses"]}, + {"header_size", raw_header["header_size"]}, + {"page_size", raw_header["page_size"]}, + {"index_fields", raw_header["index_fields"]}, + {"blob_fields", raw_header["schema"][0]["blob_fields"]}, + {"schema", raw_header["schema"][0]["schema"]}, + {"version", raw_header["version"]}}; + return {SUCCESS, header}; +} - for (const auto &addr : addresses) { - std::string absolute_path = parent_dir + string(addr); - real_addresses.emplace_back(absolute_path); - } +MSRStatus ShardHeader::BuildDataset(const std::vector &file_paths, bool load_dataset) { uint32_t thread_num = std::thread::hardware_concurrency(); if (thread_num == 0) thread_num = kThreadNumber; uint32_t work_thread_num = 0; - uint32_t addr_count = real_addresses.size(); - int group_num = ceil(addr_count * 1.0 / thread_num); + uint32_t shard_count = file_paths.size(); + int group_num = ceil(shard_count * 1.0 / thread_num); std::vector thread_set(thread_num); - std::vector headers(addr_count); + std::vector headers(shard_count); for (uint32_t x = 0; x < thread_num; ++x) { int start_num = x * group_num; - int end_num = ((x + 1) * group_num > addr_count) ? addr_count : (x + 1) * group_num; + int end_num = ((x + 1) * group_num > shard_count) ? 
shard_count : (x + 1) * group_num; if (start_num >= end_num) { continue; } thread_set[x] = - std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), real_addresses); + std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths); work_thread_num++; } @@ -180,7 +181,7 @@ MSRStatus ShardHeader::Build(const std::string &file_path) { thread_status = false; return FAILED; } - if (SUCCESS != InitializeHeader(headers)) { + if (SUCCESS != InitializeHeader(headers, load_dataset)) { return FAILED; } return SUCCESS; @@ -247,7 +248,8 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { return SUCCESS; } -void ShardHeader::ParsePage(const json &pages) { +void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { + // set shard_index when load_dataset is false if (pages_.empty() && shard_count_ <= kMaxShardCount) { pages_.resize(shard_count_); } @@ -267,7 +269,11 @@ void ShardHeader::ParsePage(const json &pages) { std::shared_ptr parsed_page = std::make_shared(page_id, shard_id, page_type, page_type_id, start_row_id, end_row_id, row_group_ids, page_size); - pages_[shard_id].push_back(std::move(parsed_page)); + if (load_dataset == true) { + pages_[shard_id].push_back(std::move(parsed_page)); + } else { + pages_[shard_index].push_back(std::move(parsed_page)); + } } } @@ -709,7 +715,7 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { std::string line; while (std::getline(page_in_handle, line)) { - ParsePage(json::parse(line)); + ParsePage(json::parse(line), -1, true); } page_in_handle.close(); diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 73bd025e1..11ae1a2d8 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2189,7 +2189,7 @@ class MindDataset(SourceDataset): A source dataset that reads from shard files and database. Args: - dataset_file (str): one of file names in dataset. + dataset_file (str, list[str]): One of file names or file list in dataset. columns_list (list[str], optional): List of columns to be read (default=None). num_parallel_workers (int, optional): The number of readers (default=None). shuffle (bool, optional): Whether or not to perform shuffle on the dataset @@ -2214,6 +2214,10 @@ class MindDataset(SourceDataset): shuffle=None, num_shards=None, shard_id=None, block_reader=False, sampler=None): super().__init__(num_parallel_workers) + if isinstance(dataset_file, list): + self.load_dataset = False + else: + self.load_dataset = True self.dataset_file = dataset_file self.columns_list = columns_list self.global_shuffle = shuffle @@ -2256,6 +2260,7 @@ class MindDataset(SourceDataset): def get_args(self): args = super().get_args() args["dataset_file"] = self.dataset_file + args["load_dataset"] = self.load_dataset args["columns_list"] = self.columns_list args["global_shuffle"] = self.global_shuffle args["partitions"] = self.partitions @@ -2272,8 +2277,11 @@ class MindDataset(SourceDataset): Return: Number, number of batches. 
""" - - num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler) + if self.load_dataset: + dataset_file = [self.dataset_file] + else: + dataset_file = self.dataset_file + num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler) if self.partitions is not None and self.partitions[0] > 0: if num_rows % self.partitions[0] == 0: num_rows = num_rows // self.partitions[0] diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index cd67ef326..9b8c130a8 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -529,8 +529,11 @@ def check_minddataset(method): dataset_file = param_dict.get('dataset_file') if dataset_file is None: raise ValueError("dataset_file is not provided.") - check_dataset_file(dataset_file) - + if isinstance(dataset_file, list): + for f in dataset_file: + check_dataset_file(f) + else: + check_dataset_file(dataset_file) check_param_type(nreq_param_int, param_dict, int) check_param_type(nreq_param_list, param_dict, list) diff --git a/mindspore/mindrecord/filereader.py b/mindspore/mindrecord/filereader.py index 80d6ffc78..ba48fb8cc 100644 --- a/mindspore/mindrecord/filereader.py +++ b/mindspore/mindrecord/filereader.py @@ -28,7 +28,7 @@ class FileReader: Class to read MindRecord File series. Args: - file_name (str): File name of MindRecord File. + file_name (str, list[str]): One of MindRecord File or file list. num_consumer(int, optional): Number of consumer threads which load data to memory (default=4). It should not be smaller than 1 or larger than the number of CPU. columns (list[str], optional): List of fields which correspond data would be read (default=None). @@ -38,8 +38,11 @@ class FileReader: ParamValueError: If file_name, num_consumer or columns is invalid. """ def __init__(self, file_name, num_consumer=4, columns=None, operator=None): - check_filename(file_name) - self._file_name = file_name + if isinstance(file_name, list): + for f in file_name: + check_filename(f) + else: + check_filename(file_name) if num_consumer is not None: if isinstance(num_consumer, int): diff --git a/mindspore/mindrecord/mindpage.py b/mindspore/mindrecord/mindpage.py index 4baaa6013..6b88ba85a 100644 --- a/mindspore/mindrecord/mindpage.py +++ b/mindspore/mindrecord/mindpage.py @@ -28,7 +28,7 @@ class MindPage: Class to read MindRecord File series in pagination. Args: - file_name (str): File name of MindRecord File. + file_name (str): One of MindRecord File or file list. num_consumer(int, optional): Number of consumer threads which load data to memory (default=4). It should not be smaller than 1 or larger than the number of CPU. @@ -37,8 +37,11 @@ class MindPage: MRMInitSegmentError: If failed to initialize ShardSegment. """ def __init__(self, file_name, num_consumer=4): - check_filename(file_name) - self._file_name = file_name + if isinstance(file_name, list): + for f in file_name: + check_filename(f) + else: + check_filename(file_name) if num_consumer is not None: if isinstance(num_consumer, int): diff --git a/mindspore/mindrecord/shardreader.py b/mindspore/mindrecord/shardreader.py index de4b82f95..f3fc6ffc3 100644 --- a/mindspore/mindrecord/shardreader.py +++ b/mindspore/mindrecord/shardreader.py @@ -35,7 +35,7 @@ class ShardReader: Open file and prepare to read MindRecord File. Args: - file_name (str): File name of MindRecord File. + file_name (str, list[str]): File names of MindRecord File. num_consumer (int): Number of worker threads which load data in parallel. 
Default: 4. columns (list[str]): List of fields which correspond data would be read. operator(int): Reserved parameter for operators. Default: None. @@ -48,7 +48,12 @@ class ShardReader: """ columns = columns if columns else [] operator = operator if operator else [] - ret = self._reader.open(file_name, num_consumer, columns, operator) + if isinstance(file_name, list): + load_dataset = False + else: + load_dataset = True + file_name = [file_name] + ret = self._reader.open(file_name, load_dataset, num_consumer, columns, operator) if ret != ms.MSRStatus.SUCCESS: logger.error("Failed to open {}.".format(file_name)) raise MRMOpenError diff --git a/mindspore/mindrecord/shardsegment.py b/mindspore/mindrecord/shardsegment.py index 963fe25fb..1d450b619 100644 --- a/mindspore/mindrecord/shardsegment.py +++ b/mindspore/mindrecord/shardsegment.py @@ -40,7 +40,7 @@ class ShardSegment: Initialize the ShardSegment. Args: - file_name (str): File name of MindRecord File. + file_name (str, list[str]): File names of MindRecord File. num_consumer (int): Number of worker threads which load data in parallel. Default: 4. columns (list[str]): List of fields which correspond data would be read. operator(int): Reserved parameter for operators. Default: None. @@ -53,7 +53,12 @@ class ShardSegment: """ self._columns = columns if columns else [] operator = operator if operator else [] - ret = self._segment.open(file_name, num_consumer, self._columns, operator) + if isinstance(file_name, list): + load_dataset = False + else: + load_dataset = True + file_name = [file_name] + ret = self._segment.open(file_name, load_dataset, num_consumer, self._columns, operator) if ret != SUCCESS: logger.error("Failed to open {}.".format(file_name)) raise MRMOpenError diff --git a/tests/ut/cpp/dataset/mind_record_op_test.cc b/tests/ut/cpp/dataset/mind_record_op_test.cc index 90f41fdeb..b2cbdf027 100644 --- a/tests/ut/cpp/dataset/mind_record_op_test.cc +++ b/tests/ut/cpp/dataset/mind_record_op_test.cc @@ -62,7 +62,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBasic) { std::shared_ptr my_mindrecord_op; MindRecordOp::Builder builder; - builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") + builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) + .SetLoadDataset(true) .SetRowsPerBuffer(3) .SetNumMindRecordWorkers(4) .SetColumnsToLoad(column_list); @@ -132,7 +133,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordSample) { std::shared_ptr my_mindrecord_op; MindRecordOp::Builder builder; - builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") + builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) + .SetLoadDataset(true) .SetRowsPerBuffer(3) .SetNumMindRecordWorkers(4) .SetColumnsToLoad(column_list) @@ -203,7 +205,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordShuffle) { std::shared_ptr my_mindrecord_op; MindRecordOp::Builder builder; - builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") + builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) + .SetLoadDataset(true) .SetRowsPerBuffer(3) .SetNumMindRecordWorkers(4) .SetColumnsToLoad(column_list) @@ -277,7 +280,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordCategory) { std::shared_ptr my_mindrecord_op; MindRecordOp::Builder builder; - 
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") + builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) + .SetLoadDataset(true) .SetRowsPerBuffer(3) .SetNumMindRecordWorkers(4) .SetColumnsToLoad(column_list) @@ -345,7 +349,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) { std::shared_ptr my_mindrecord_op; MindRecordOp::Builder builder; - builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") + builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) + .SetLoadDataset(true) .SetRowsPerBuffer(3) .SetNumMindRecordWorkers(4) .SetColumnsToLoad(column_list); @@ -426,7 +431,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) { std::shared_ptr my_mindrecord_op; MindRecordOp::Builder builder; - builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") + builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) + .SetLoadDataset(true) .SetRowsPerBuffer(3) .SetNumMindRecordWorkers(4) .SetBlockReader() @@ -507,7 +513,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordInvalidColumnList) { std::shared_ptr my_mindrecord_op; MindRecordOp::Builder builder; - builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") + builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) + .SetLoadDataset(true) .SetRowsPerBuffer(3) .SetNumMindRecordWorkers(4) .SetColumnsToLoad(column_list); diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 6fc5ccbbe..23c2c1e34 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -63,7 +63,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) { std::vector> ops; ops.push_back(std::make_shared(kSampleCount)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -89,7 +89,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) { ops.push_back(std::make_shared(kNum, kDen)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -115,7 +115,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) { ops.push_back(std::make_shared(kNum, kDen)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -144,7 +144,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { ASSERT_TRUE(partitions.second == 2); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -168,7 +168,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) { ops.push_back(std::make_shared("label", 2)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name},true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -193,7 +193,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { ops.push_back(std::make_shared("label", 2, 3, 0)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + 
dataset.Open({file_name},true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -223,7 +223,7 @@ TEST_F(TestShardOperator, TestShardCategory) { ops.push_back(std::make_shared(categories)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -254,7 +254,7 @@ TEST_F(TestShardOperator, TestShardShuffle) { ops.push_back(std::make_shared(1)); ShardReader dataset; - dataset.Open(file_name, 16, column_list, ops); + dataset.Open({file_name}, true, 16, column_list, ops); dataset.Launch(); int i = 0; @@ -279,7 +279,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) { ops.push_back(std::make_shared(1)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -306,7 +306,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) { ops.push_back(std::make_shared(kSampleSize)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -333,7 +333,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) { ops.push_back(std::make_shared(35)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -357,11 +357,11 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) { ops.push_back(std::make_shared(1)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); ShardReader compare_dataset; - compare_dataset.Open(file_name, 4, column_list); + compare_dataset.Open({file_name},true, 4, column_list); compare_dataset.Launch(); int i = 0; @@ -396,7 +396,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) { ops.push_back(std::make_shared(21)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -430,7 +430,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) { ops.push_back(std::make_shared(categories)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -464,7 +464,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) { ops.push_back(std::make_shared(categories)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name},true, 4, column_list, ops); dataset.Launch(); int i = 0; @@ -502,7 +502,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) { ops.push_back(std::make_shared(100)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); int i = 0; diff --git a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc index e88c2fe3d..c532fe28b 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc @@ -55,7 +55,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) { auto column_list = std::vector{"file_name"}; ShardReader dataset; - dataset.Open(file_name, 4, column_list); + dataset.Open({file_name}, true, 4, column_list); dataset.Launch(); while (true) { @@ -78,7 +78,7 @@ TEST_F(TestShardReader, TestShardReaderSample) { std::vector> ops; 
ops.push_back(std::make_shared(17)); ShardReader dataset; - dataset.Open(file_name, 4, column_list, ops); + dataset.Open({file_name}, true, 4, column_list, ops); dataset.Launch(); while (true) { @@ -103,7 +103,7 @@ TEST_F(TestShardReader, TestShardReaderBlock) { ops.push_back(std::make_shared(3)); ShardReader dataset; const bool kBlockReader = true; - dataset.Open(file_name, 4, column_list, ops, kBlockReader); + dataset.Open({file_name}, true, 4, column_list, ops, kBlockReader); dataset.Launch(); while (true) { @@ -123,7 +123,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) { MS_LOG(INFO) << FormatInfo("Test read imageNet"); std::string file_name = "./imagenet.shard01"; ShardReader dataset; - dataset.Open(file_name); + dataset.Open({file_name}, true); dataset.Launch(); while (true) { @@ -143,7 +143,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { std::string file_name = "./imagenet.shard01"; auto column_list = std::vector{"label"}; ShardReader dataset; - MSRStatus ret = dataset.Open(file_name, 4, column_list); + MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); ASSERT_EQ(ret, SUCCESS); dataset.Launch(); @@ -164,7 +164,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { std::string file_name = "./imagenet.shard01"; auto column_list = std::vector{"file_namex"}; ShardReader dataset; - MSRStatus ret = dataset.Open(file_name, 4, column_list); + MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST); } @@ -172,7 +172,7 @@ TEST_F(TestShardReader, TestShardVersion) { MS_LOG(INFO) << FormatInfo("Test shard version"); std::string file_name = "./imagenet.shard01"; ShardReader dataset; - MSRStatus ret = dataset.Open(file_name, 4); + MSRStatus ret = dataset.Open({file_name}, true, 4); ASSERT_EQ(ret, SUCCESS); dataset.Launch(); @@ -195,7 +195,7 @@ TEST_F(TestShardReader, TestShardReaderDir) { auto column_list = std::vector{"file_name"}; ShardReader dataset; - MSRStatus ret = dataset.Open(file_name, 4, column_list); + MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); ASSERT_EQ(ret, FAILED); } @@ -205,7 +205,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) { auto column_list = std::vector{"file_name"}; ShardReader dataset; - dataset.Open(file_name, -481565535, column_list); + dataset.Open({file_name}, true, -481565535, column_list); dataset.Launch(); while (true) { diff --git a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc index bf0a35df7..3fa681235 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc @@ -59,7 +59,7 @@ TEST_F(TestShardSegment, TestShardSegment) { std::string file_name = "./imagenet.shard01"; ShardSegment dataset; - dataset.Open(file_name, 4); + dataset.Open({file_name}, true, 4); auto x = dataset.GetCategoryFields(); for (const auto &fields : x.second) { @@ -97,7 +97,7 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { std::string file_name = "./imagenet.shard01"; ShardSegment dataset; - dataset.Open(file_name, 4); + dataset.Open({file_name}, true, 4); auto x = dataset.GetCategoryFields(); for (const auto &fields : x.second) { @@ -121,7 +121,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { std::string file_name = "./imagenet.shard01"; ShardSegment dataset; - dataset.Open(file_name, 4); + dataset.Open({file_name}, true, 4); auto x = dataset.GetCategoryFields(); for (const auto &fields : x.second) { @@ -143,7 +143,7 @@ 
TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { std::string file_name = "./imagenet.shard01"; ShardSegment dataset; - dataset.Open(file_name, 4); + dataset.Open({file_name}, true, 4); auto x = dataset.GetCategoryFields(); for (const auto &fields : x.second) { @@ -165,7 +165,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { std::string file_name = "./imagenet.shard01"; ShardSegment dataset; - dataset.Open(file_name, 4); + dataset.Open({file_name}, true, 4); auto x = dataset.GetCategoryFields(); for (const auto &fields : x.second) { diff --git a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc index 71da456e7..159efbf2f 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc @@ -60,7 +60,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) { std::string filename = "./OneSample.shard01"; ShardReader dataset; - MSRStatus ret = dataset.Open(filename, 4); + MSRStatus ret = dataset.Open({filename}, true, 4); ASSERT_EQ(ret, SUCCESS); dataset.Launch(); @@ -756,7 +756,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { filename = "./imagenet.shard01"; auto column_list = std::vector{"label", "file_name", "data"}; ShardReader dataset; - MSRStatus ret = dataset.Open(filename, 4, column_list); + MSRStatus ret = dataset.Open({filename}, true, 4, column_list); ASSERT_EQ(ret, SUCCESS); dataset.Launch(); @@ -842,7 +842,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) { filename = "./imagenet.shard01"; auto column_list = std::vector{"label", "file_name"}; ShardReader dataset; - MSRStatus ret = dataset.Open(filename, 4, column_list); + MSRStatus ret = dataset.Open({filename}, true, 4, column_list); ASSERT_EQ(ret, SUCCESS); dataset.Launch(); @@ -936,7 +936,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { filename = "./imagenet.shard01"; auto column_list = std::vector{"label", "data"}; ShardReader dataset; - MSRStatus ret = dataset.Open(filename, 4, column_list); + MSRStatus ret = dataset.Open({filename}, true, 4, column_list); ASSERT_EQ(ret, SUCCESS); dataset.Launch(); @@ -1043,7 +1043,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) { filename = "./TenSampleFortyShard.shard01"; ShardReader dataset; - MSRStatus ret = dataset.Open(filename, 4); + MSRStatus ret = dataset.Open({filename}, true, 4); ASSERT_EQ(ret, SUCCESS); dataset.Launch(); diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py index ba0c86dc8..02cad1e6c 100644 --- a/tests/ut/python/dataset/test_minddataset.py +++ b/tests/ut/python/dataset/test_minddataset.py @@ -32,6 +32,8 @@ from mindspore.mindrecord import FileWriter FILES_NUM = 4 CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord" +CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord" +CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord" CV_DIR_NAME = "../data/mindrecord/testImageNetData" NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord" NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos" @@ -111,7 +113,6 @@ def test_cv_minddataset_writer_tutorial(): os.remove("{}".format(x)) os.remove("{}.db".format(x)) - def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): """tutorial for cv minddataset.""" columns_list = ["data", "file_name", "label"] @@ -247,6 +248,126 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem assert num_iter == 20 +def 
test_cv_minddataset_reader_file_list(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)], columns_list, num_readers) + assert data_set.get_dataset_size() == 10 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + assert num_iter == 10 + +def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + data_set = ds.MindDataset([CV_FILE_NAME + "0"], columns_list, num_readers) + assert data_set.get_dataset_size() < 10 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) + logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + assert num_iter < 10 + +def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + if os.path.exists(CV1_FILE_NAME): + os.remove(CV1_FILE_NAME) + if os.path.exists("{}.db".format(CV1_FILE_NAME)): + os.remove("{}.db".format(CV1_FILE_NAME)) + if os.path.exists(CV2_FILE_NAME): + os.remove(CV2_FILE_NAME) + if os.path.exists("{}.db".format(CV2_FILE_NAME)): + os.remove("{}.db".format(CV2_FILE_NAME)) + writer = FileWriter(CV1_FILE_NAME, 1) + data = get_data(CV_DIR_NAME) + cv_schema_json = {"id": {"type": "int32"}, + "file_name": {"type": "string"}, + "label": {"type": "int32"}, + "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "CV1_schema") + writer.add_index(["file_name", "label"]) + writer.write_raw_data(data) + writer.commit() + + writer = FileWriter(CV2_FILE_NAME, 1) + data = get_data(CV_DIR_NAME) + cv_schema_json = {"id": {"type": "int32"}, + "file_name": {"type": "string"}, + "label": {"type": "int32"}, + "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "CV2_schema") + writer.add_index(["file_name", "label"]) + writer.write_raw_data(data) + writer.commit() + columns_list = ["data", "file_name", "label"] + num_readers = 4 + data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME], columns_list, num_readers) + assert data_set.get_dataset_size() == 30 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) + logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) + 
+def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    if os.path.exists(CV1_FILE_NAME):
+        os.remove(CV1_FILE_NAME)
+    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
+        os.remove("{}.db".format(CV1_FILE_NAME))
+    if os.path.exists(CV2_FILE_NAME):
+        os.remove(CV2_FILE_NAME)
+    if os.path.exists("{}.db".format(CV2_FILE_NAME)):
+        os.remove("{}.db".format(CV2_FILE_NAME))
+    writer = FileWriter(CV1_FILE_NAME, 1)
+    data = get_data(CV_DIR_NAME)
+    cv_schema_json = {"id": {"type": "int32"},
+                      "file_name": {"type": "string"},
+                      "label": {"type": "int32"},
+                      "data": {"type": "bytes"}}
+    writer.add_schema(cv_schema_json, "CV1_schema")
+    writer.add_index(["file_name", "label"])
+    writer.write_raw_data(data)
+    writer.commit()
+
+    writer = FileWriter(CV2_FILE_NAME, 1)
+    data = get_data(CV_DIR_NAME)
+    cv_schema_json = {"id": {"type": "int32"},
+                      "file_name": {"type": "string"},
+                      "label": {"type": "int32"},
+                      "data": {"type": "bytes"}}
+    writer.add_schema(cv_schema_json, "CV2_schema")
+    writer.add_index(["file_name", "label"])
+    writer.write_raw_data(data)
+    writer.commit()
+
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] +
+                              [CV1_FILE_NAME, CV2_FILE_NAME],
+                              columns_list, num_readers)
+    assert data_set.get_dataset_size() == 30
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
+        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 30
+    if os.path.exists(CV1_FILE_NAME):
+        os.remove(CV1_FILE_NAME)
+    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
+        os.remove("{}.db".format(CV1_FILE_NAME))
+    if os.path.exists(CV2_FILE_NAME):
+        os.remove(CV2_FILE_NAME)
+    if os.path.exists("{}.db".format(CV2_FILE_NAME)):
+        os.remove("{}.db".format(CV2_FILE_NAME))
+
+def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
+    paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0'))
+             for x in range(FILES_NUM)]
+    for x in paths:
+        if os.path.exists("{}".format(x)):
+            os.remove("{}".format(x))
+        if os.path.exists("{}.db".format(x)):
+            os.remove("{}.db".format(x))
+    writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
+    data = get_data(CV_DIR_NAME)
+    cv_schema_json = {"id": {"type": "int32"},
+                      "file_name": {"type": "string"},
+                      "label": {"type": "int32"},
+                      "data": {"type": "bytes"}}
+    writer.add_schema(cv_schema_json, "CV1_schema")
+    writer.add_index(["file_name", "label"])
+    writer.write_raw_data(data)
+    writer.commit()
+
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(2)] +
+                              [CV1_FILE_NAME + str(x) for x in range(2, 4)],
+                              columns_list, num_readers)
+    assert data_set.get_dataset_size() < 20
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
+        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter < 20
+    for x in paths:
+        os.remove("{}".format(x))
+        os.remove("{}.db".format(x))
+
+
 def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
     """tutorial for cv minderdataset."""
     columns_list = ["data", "file_name", "label"]
diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py
index 2a269ffc8..53a719a98 100644
--- a/tests/ut/python/dataset/test_minddataset_exception.py
+++ b/tests/ut/python/dataset/test_minddataset_exception.py
@@ -22,6 +22,7 @@ import mindspore.dataset as ds
 from mindspore.mindrecord import FileWriter
 
 CV_FILE_NAME = "./imagenet.mindrecord"
+CV1_FILE_NAME = "./imagenet1.mindrecord"
 
 
 def create_cv_mindrecord(files_num):
@@ -37,6 +38,31 @@ def create_cv_mindrecord(files_num):
     writer.commit()
 
 
+def create_diff_schema_cv_mindrecord(files_num):
+    """create a mindrecord file whose schema differs from create_cv_mindrecord."""
+    if os.path.exists(CV1_FILE_NAME):
+        os.remove(CV1_FILE_NAME)
+    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
+        os.remove("{}.db".format(CV1_FILE_NAME))
+    writer = FileWriter(CV1_FILE_NAME, files_num)
+    cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
+    data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
+    writer.add_schema(cv_schema_json, "img_schema")
+    writer.add_index(["file_name_1", "label"])
+    writer.write_raw_data(data)
+    writer.commit()
+
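+# Counterpart to create_diff_schema_cv_mindrecord above: the schema matches
+# create_cv_mindrecord and only the page size differs, so the page-size
+# error test below fails for exactly one reason.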
+def create_diff_page_size_cv_mindrecord(files_num):
+    """create a mindrecord file with a non-default page size."""
+    if os.path.exists(CV1_FILE_NAME):
+        os.remove(CV1_FILE_NAME)
+    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
+        os.remove("{}.db".format(CV1_FILE_NAME))
+    writer = FileWriter(CV1_FILE_NAME, files_num)
+    writer.set_page_size(1 << 26)  # 64MB
+    cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
+    data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
+    writer.add_schema(cv_schema_json, "img_schema")
+    writer.add_index(["file_name", "label"])
+    writer.write_raw_data(data)
+    writer.commit()
+
 def test_cv_lack_json():
     """tutorial for cv minderdataset."""
     create_cv_mindrecord(1)
@@ -111,3 +137,34 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))
 
+def test_cv_minddataset_reader_different_schema():
+    """reading two mindrecord files whose schemas differ should fail at init."""
+    create_cv_mindrecord(1)
+    create_diff_schema_cv_mindrecord(1)
+    columns_list = ["data", "label"]
+    num_readers = 4
+    with pytest.raises(Exception, match="MindRecordOp init failed"):
+        data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
+                                  num_readers)
+        num_iter = 0
+        for item in data_set.create_dict_iterator():
+            num_iter += 1
+    os.remove(CV_FILE_NAME)
+    os.remove("{}.db".format(CV_FILE_NAME))
+    os.remove(CV1_FILE_NAME)
+    os.remove("{}.db".format(CV1_FILE_NAME))
+
+def test_cv_minddataset_reader_different_page_size():
+    """reading two mindrecord files whose page sizes differ should fail at init."""
+    create_cv_mindrecord(1)
+    create_diff_page_size_cv_mindrecord(1)
+    columns_list = ["data", "label"]
+    num_readers = 4
+    with pytest.raises(Exception, match="MindRecordOp init failed"):
+        data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
+                                  num_readers)
+        num_iter = 0
+        for item in data_set.create_dict_iterator():
+            num_iter += 1
+    os.remove(CV_FILE_NAME)
+    os.remove("{}.db".format(CV_FILE_NAME))
+    os.remove(CV1_FILE_NAME)
+    os.remove("{}.db".format(CV1_FILE_NAME))
diff --git a/tests/ut/python/mindrecord/test_mindrecord_base.py b/tests/ut/python/mindrecord/test_mindrecord_base.py
index 778ebccf8..da424122b 100644
--- a/tests/ut/python/mindrecord/test_mindrecord_base.py
+++ b/tests/ut/python/mindrecord/test_mindrecord_base.py
@@ -202,6 +202,16 @@ def test_cv_file_reader_tutorial():
     assert count == 10
     reader.close()
 
+def test_cv_file_reader_file_list():
+    """tutorial for cv file list reader."""
+    reader = FileReader([CV_FILE_NAME + str(x) for x in range(FILES_NUM)])
+    count = 0
+    for index, x in enumerate(reader.get_next()):
+        assert len(x) == 3
+        count = count + 1
+        logger.info("#item{}: {}".format(index, x))
+    assert count == 10
+    reader.close()
+
 def test_cv_file_reader_partial_tutorial():
     """tutorial for cv file partial reader."""
     reader = FileReader(CV_FILE_NAME + "0")
-- 
GitLab
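
A minimal usage sketch of the two user-facing Python entry points this patch
extends (file names here are placeholders; a single file name string is still
accepted, as in the unchanged tutorials above):

    import mindspore.dataset as ds
    from mindspore.mindrecord import FileReader

    file_list = ["imagenet.mindrecord0", "imagenet.mindrecord1"]

    # MindDataset reads the listed shard files as one dataset source.
    data_set = ds.MindDataset(file_list, ["data", "file_name", "label"], 4)
    for item in data_set.create_dict_iterator():
        print(item["file_name"], item["label"])

    # FileReader accepts the same list form for record-level access.
    reader = FileReader(file_list)
    for index, fields in enumerate(reader.get_next()):
        print("#item{}: {} fields".format(index, len(fields)))
    reader.close()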