提交 1a98c6b4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!999 [MD] mindrecord support reading file list

Merge pull request !999 from liyong126/mindrecord_file_list
...@@ -408,8 +408,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas ...@@ -408,8 +408,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
} }
std::shared_ptr<MindRecordOp::Builder> builder = std::make_shared<MindRecordOp::Builder>(); std::shared_ptr<MindRecordOp::Builder> builder = std::make_shared<MindRecordOp::Builder>();
(void)builder->SetDatasetFile(ToString(args["dataset_file"])); bool load_dataset = ToBool(args["load_dataset"]);
if (load_dataset == true) {
(void)builder->SetDatasetFile({ToString(args["dataset_file"])});
} else {
(void)builder->SetDatasetFile(ToStringVector(args["dataset_file"]));
}
(void)builder->SetLoadDataset(load_dataset);
std::vector<std::string> in_col_names; std::vector<std::string> in_col_names;
if (!args["columns_list"].is_none()) { if (!args["columns_list"].is_none()) {
in_col_names = ToStringVector(args["columns_list"]); in_col_names = ToStringVector(args["columns_list"]);
......
...@@ -151,16 +151,17 @@ void bindDatasetOps(py::module *m) { ...@@ -151,16 +151,17 @@ void bindDatasetOps(py::module *m) {
}); });
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp") (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
.def_static("get_num_rows", [](const std::string &path, const py::object &sampler) { .def_static("get_num_rows",
int64_t count = 0; [](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler) {
std::shared_ptr<mindrecord::ShardOperator> op; int64_t count = 0;
if (py::hasattr(sampler, "_create_for_minddataset")) { std::shared_ptr<mindrecord::ShardOperator> op;
auto create = sampler.attr("_create_for_minddataset"); if (py::hasattr(sampler, "_create_for_minddataset")) {
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>(); auto create = sampler.attr("_create_for_minddataset");
} op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count)); }
return count; THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count));
}); return count;
});
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp") (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
.def_static("get_num_rows_and_classes", .def_static("get_num_rows_and_classes",
......
...@@ -40,7 +40,7 @@ using mindrecord::ShardOperator; ...@@ -40,7 +40,7 @@ using mindrecord::ShardOperator;
using mindrecord::ShardReader; using mindrecord::ShardReader;
// Builder constructor. Creates the builder object. // Builder constructor. Creates the builder object.
MindRecordOp::Builder::Builder() : build_dataset_file_("") { MindRecordOp::Builder::Builder() : build_dataset_file_({}) {
// Some arguments to the MindRecordOp constructor have a default argument that is taken // Some arguments to the MindRecordOp constructor have a default argument that is taken
// from the client config. // from the client config.
// The user may choose to change these values for the construction of the StorageOp by // The user may choose to change these values for the construction of the StorageOp by
...@@ -63,9 +63,9 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) { ...@@ -63,9 +63,9 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
"Building a MindRecordOp that has not provided a file."); "Building a MindRecordOp that has not provided a file.");
} }
new_mind_record_op = std::make_shared<MindRecordOp>(build_num_mind_record_workers_, build_rows_per_buffer_, new_mind_record_op = std::make_shared<MindRecordOp>(
build_dataset_file_, build_op_connector_queue_size_, build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_,
build_columns_to_load_, build_operators_, build_block_reader_); build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_);
RETURN_IF_NOT_OK(new_mind_record_op->Init()); RETURN_IF_NOT_OK(new_mind_record_op->Init());
...@@ -76,12 +76,14 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) { ...@@ -76,12 +76,14 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); } Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); }
// Constructor of the MindRecordOp. // Constructor of the MindRecordOp.
MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file, MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer,
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load, std::vector<std::string> dataset_file, bool load_dataset, int32_t op_connector_queue_size,
const std::vector<std::string> &columns_to_load,
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader) const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader)
: ParallelOp(num_mind_record_workers, op_connector_queue_size), : ParallelOp(num_mind_record_workers, op_connector_queue_size),
rows_per_buffer_(rows_per_buffer), rows_per_buffer_(rows_per_buffer),
dataset_file_(dataset_file), dataset_file_(dataset_file),
load_dataset_(load_dataset),
columns_to_load_(columns_to_load), columns_to_load_(columns_to_load),
operators_(operators), operators_(operators),
num_mind_record_workers_(num_mind_record_workers), num_mind_record_workers_(num_mind_record_workers),
...@@ -101,9 +103,10 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf ...@@ -101,9 +103,10 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
// Private helper method to encapsulate some common construction/reset tasks // Private helper method to encapsulate some common construction/reset tasks
Status MindRecordOp::Init() { Status MindRecordOp::Init() {
shard_reader_ = std::make_unique<ShardReader>(); shard_reader_ = std::make_unique<ShardReader>();
auto rc = shard_reader_->Open(dataset_file_, num_mind_record_workers_, columns_to_load_, operators_, block_reader_); auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_,
block_reader_);
CHECK_FAIL_RETURN_UNEXPECTED(rc != MSRStatus::FAILED, CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS,
"MindRecordOp init failed. Error message: " + ErrnoToMessage(rc)); "MindRecordOp init failed. Error message: " + ErrnoToMessage(rc));
data_schema_ = std::make_unique<DataSchema>(); data_schema_ = std::make_unique<DataSchema>();
...@@ -201,8 +204,12 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { ...@@ -201,8 +204,12 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all); ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff // Then show any custom derived-internal stuff
out << "\n1 Dataset file : " << dataset_file_ << "\nNumber of rows : " << num_rows_ out << "\n Dataset file : ";
<< "\nRows per buffer : " << rows_per_buffer_ << "\nNumber of buffers : " << buffers_needed_ for (auto &file : dataset_file_) {
out << file << " ";
}
out << "\nNumber of rows : " << num_rows_ << "\nRows per buffer : " << rows_per_buffer_
<< "\nNumber of buffers : " << buffers_needed_
<< "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n"; << "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n";
} }
} }
...@@ -668,10 +675,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() { ...@@ -668,10 +675,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
return Status::OK(); return Status::OK();
} }
Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op, Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
int64_t *count) { const std::shared_ptr<ShardOperator> &op, int64_t *count) {
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>(); std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count); MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count);
if (rc == MSRStatus::FAILED) { if (rc == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed.");
} }
......
...@@ -77,8 +77,8 @@ class MindRecordOp : public ParallelOp { ...@@ -77,8 +77,8 @@ class MindRecordOp : public ParallelOp {
return *this; return *this;
} }
Builder &SetDatasetFile(const std::string &file) { Builder &SetDatasetFile(const std::vector<std::string> &files) {
build_dataset_file_ = file; build_dataset_file_ = files;
return *this; return *this;
} }
...@@ -97,6 +97,11 @@ class MindRecordOp : public ParallelOp { ...@@ -97,6 +97,11 @@ class MindRecordOp : public ParallelOp {
return *this; return *this;
} }
Builder &SetLoadDataset(bool load_dataset) {
build_load_dataset_ = load_dataset;
return *this;
}
Status SanityCheck() const; Status SanityCheck() const;
static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; } static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; }
...@@ -109,7 +114,8 @@ class MindRecordOp : public ParallelOp { ...@@ -109,7 +114,8 @@ class MindRecordOp : public ParallelOp {
int32_t builder_num_workers_; int32_t builder_num_workers_;
int32_t build_rows_per_buffer_; int32_t build_rows_per_buffer_;
int32_t build_op_connector_queue_size_; int32_t build_op_connector_queue_size_;
std::string build_dataset_file_; std::vector<std::string> build_dataset_file_;
bool build_load_dataset_;
std::vector<std::string> build_columns_to_load_; std::vector<std::string> build_columns_to_load_;
std::vector<std::shared_ptr<ShardOperator>> build_operators_; std::vector<std::shared_ptr<ShardOperator>> build_operators_;
bool build_block_reader_; bool build_block_reader_;
...@@ -119,12 +125,12 @@ class MindRecordOp : public ParallelOp { ...@@ -119,12 +125,12 @@ class MindRecordOp : public ParallelOp {
// @note The builder class should be used to call it // @note The builder class should be used to call it
// @param num_mind_record_workers - The number of workers for the op (run by ShardReader) // @param num_mind_record_workers - The number of workers for the op (run by ShardReader)
// @param rows_per_buffer - The requested number of rows per buffer // @param rows_per_buffer - The requested number of rows per buffer
// @param dataset_file - A shard file // @param dataset_file - dataset files
// @param op_connector_queue_size - The output connector queue size // @param op_connector_queue_size - The output connector queue size
// @param columns_to_load - The list of columns to use (column name) // @param columns_to_load - The list of columns to use (column name)
// @param operators - ShardOperators for Shuffle, Category, Sample // @param operators - ShardOperators for Shuffle, Category, Sample
MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file, MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector<std::string> dataset_file,
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load, bool load_dataset, int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader); const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader);
// Destructor // Destructor
...@@ -169,21 +175,22 @@ class MindRecordOp : public ParallelOp { ...@@ -169,21 +175,22 @@ class MindRecordOp : public ParallelOp {
// Getter method // Getter method
int32_t num_rows() const { return num_rows_; } int32_t num_rows() const { return num_rows_; }
// Getter method static Status CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op, const std::shared_ptr<ShardOperator> &op, int64_t *count);
int64_t *count);
// Getter method // Getter method
int32_t rows_per_buffer() const { return rows_per_buffer_; } int32_t rows_per_buffer() const { return rows_per_buffer_; }
// Getter method // Getter method
std::string dataset_file() const { return dataset_file_; } std::vector<std::string> dataset_file() const { return dataset_file_; }
// Getter method // Getter method
std::vector<std::string> columns_to_load() const { return columns_to_load_; } std::vector<std::string> columns_to_load() const { return columns_to_load_; }
bool block_reader() const { return block_reader_; } bool block_reader() const { return block_reader_; }
bool load_dataset() const { return load_dataset_; }
Status Init(); Status Init();
Status SetColumnsBlob(); Status SetColumnsBlob();
...@@ -246,7 +253,8 @@ class MindRecordOp : public ParallelOp { ...@@ -246,7 +253,8 @@ class MindRecordOp : public ParallelOp {
Status FetchBlockBuffer(const int32_t &buffer_id); Status FetchBlockBuffer(const int32_t &buffer_id);
int32_t rows_per_buffer_; // The number of requested rows per buffer. int32_t rows_per_buffer_; // The number of requested rows per buffer.
std::string dataset_file_; // A dataset file std::vector<std::string> dataset_file_; // dataset files
bool load_dataset_; // load dataset from single file or not
std::vector<std::string> columns_to_load_; // Columns to load from dataset std::vector<std::string> columns_to_load_; // Columns to load from dataset
std::vector<std::shared_ptr<ShardOperator>> operators_; // ShardOperators to use std::vector<std::shared_ptr<ShardOperator>> operators_; // ShardOperators to use
int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader
......
...@@ -170,6 +170,9 @@ std::string ErrnoToMessage(MSRStatus status) { ...@@ -170,6 +170,9 @@ std::string ErrnoToMessage(MSRStatus status) {
case IO_FAILED: case IO_FAILED:
return "io operate failed"; return "io operate failed";
break; break;
case MATCH_HEADER_FAILED:
return "match header failed";
break;
default: default:
return "invalid error no"; return "invalid error no";
} }
......
...@@ -84,7 +84,8 @@ void BindShardWriter(py::module *m) { ...@@ -84,7 +84,8 @@ void BindShardWriter(py::module *m) {
void BindShardReader(const py::module *m) { void BindShardReader(const py::module *m) {
(void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local()) (void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
.def(py::init<>()) .def(py::init<>())
.def("open", (MSRStatus(ShardReader::*)(const std::string &, const int &, const std::vector<std::string> &, .def("open", (MSRStatus(ShardReader::*)(const std::vector<std::string> &, bool, const int &,
const std::vector<std::string> &,
const std::vector<std::shared_ptr<ShardOperator>> &)) & const std::vector<std::shared_ptr<ShardOperator>> &)) &
ShardReader::OpenPy) ShardReader::OpenPy)
.def("launch", &ShardReader::Launch) .def("launch", &ShardReader::Launch)
...@@ -106,7 +107,8 @@ void BindShardIndexGenerator(const py::module *m) { ...@@ -106,7 +107,8 @@ void BindShardIndexGenerator(const py::module *m) {
void BindShardSegment(py::module *m) { void BindShardSegment(py::module *m) {
(void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local()) (void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
.def(py::init<>()) .def(py::init<>())
.def("open", (MSRStatus(ShardSegment::*)(const std::string &, const int &, const std::vector<std::string> &, .def("open", (MSRStatus(ShardSegment::*)(const std::vector<std::string> &, bool, const int &,
const std::vector<std::string> &,
const std::vector<std::shared_ptr<ShardOperator>> &)) & const std::vector<std::shared_ptr<ShardOperator>> &)) &
ShardSegment::OpenPy) ShardSegment::OpenPy)
.def("get_category_fields", .def("get_category_fields",
......
...@@ -72,7 +72,8 @@ enum MSRStatus { ...@@ -72,7 +72,8 @@ enum MSRStatus {
ILLEGAL_PARAMETERS, ILLEGAL_PARAMETERS,
GET_PAGE_BY_GROUP_ID_FAILED, GET_PAGE_BY_GROUP_ID_FAILED,
GET_SYSTEM_STATE_FAILED, GET_SYSTEM_STATE_FAILED,
IO_FAILED IO_FAILED,
MATCH_HEADER_FAILED
}; };
// convert error no to string message // convert error no to string message
......
...@@ -35,10 +35,11 @@ class ShardHeader { ...@@ -35,10 +35,11 @@ class ShardHeader {
public: public:
ShardHeader(); ShardHeader();
MSRStatus Build(const std::string &file_path);
~ShardHeader() = default; ~ShardHeader() = default;
MSRStatus BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);
static std::pair<MSRStatus, json> BuildSingleHeader(const std::string &file_path);
/// \brief add the schema and save it /// \brief add the schema and save it
/// \param[in] schema the schema needs to be added /// \param[in] schema the schema needs to be added
/// \return the last schema's id /// \return the last schema's id
...@@ -126,7 +127,7 @@ class ShardHeader { ...@@ -126,7 +127,7 @@ class ShardHeader {
MSRStatus FileToPages(const std::string dump_file_name); MSRStatus FileToPages(const std::string dump_file_name);
private: private:
MSRStatus InitializeHeader(const std::vector<json> &headers); MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);
/// \brief get the headers from all the shard data /// \brief get the headers from all the shard data
/// \param[in] the shard data real path /// \param[in] the shard data real path
...@@ -137,9 +138,9 @@ class ShardHeader { ...@@ -137,9 +138,9 @@ class ShardHeader {
MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
/// \brief check the binary file status /// \brief check the binary file status
MSRStatus CheckFileStatus(const std::string &path); static MSRStatus CheckFileStatus(const std::string &path);
std::pair<MSRStatus, json> ValidateHeader(const std::string &path); static std::pair<MSRStatus, json> ValidateHeader(const std::string &path);
void ParseHeader(const json &header); void ParseHeader(const json &header);
...@@ -149,7 +150,7 @@ class ShardHeader { ...@@ -149,7 +150,7 @@ class ShardHeader {
MSRStatus CheckIndexField(const std::string &field, const json &schema); MSRStatus CheckIndexField(const std::string &field, const json &schema);
void ParsePage(const json &page); void ParsePage(const json &page, int shard_index, bool load_dataset);
MSRStatus ParseStatistics(const json &statistics); MSRStatus ParseStatistics(const json &statistics);
......
...@@ -68,23 +68,25 @@ class ShardReader { ...@@ -68,23 +68,25 @@ class ShardReader {
virtual ~ShardReader(); virtual ~ShardReader();
/// \brief open files and initialize reader, c++ API /// \brief open files and initialize reader, c++ API
/// \param[in] file_path the path of ONE file, any file in dataset is fine /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] n_consumer number of threads when reading /// \param[in] n_consumer number of threads when reading
/// \param[in] selected_columns column list to be populated /// \param[in] selected_columns column list to be populated
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category /// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \param[in] block_reader block-reader mode if true, otherwise row-reader mode /// \param[in] block_reader block-reader mode if true, otherwise row-reader mode
/// \return MSRStatus the status of MSRStatus /// \return MSRStatus the status of MSRStatus
MSRStatus Open(const std::string &file_path, int n_consumer = 4, MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
const std::vector<std::string> &selected_columns = {}, const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const bool &block_reader = false); const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const bool &block_reader = false);
/// \brief open files and initialize reader, python API /// \brief open files and initialize reader, python API
/// \param[in] file_path the path of ONE file, any file in dataset is fine /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] n_consumer number of threads when reading /// \param[in] n_consumer number of threads when reading
/// \param[in] selected_columns column list to be populated /// \param[in] selected_columns column list to be populated
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category /// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \return MSRStatus the status of MSRStatus /// \return MSRStatus the status of MSRStatus
MSRStatus OpenPy(const std::string &file_path, const int &n_consumer = 4, MSRStatus OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer = 4,
const std::vector<std::string> &selected_columns = {}, const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}); const std::vector<std::shared_ptr<ShardOperator>> &operators = {});
...@@ -114,11 +116,13 @@ class ShardReader { ...@@ -114,11 +116,13 @@ class ShardReader {
int GetShardCount() const; int GetShardCount() const;
/// \brief get the number of rows in database /// \brief get the number of rows in database
/// \param[in] file_path the path of ONE file, any file in dataset is fine /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object /// \param[in] op smart pointer refer to ShardCategory or ShardSample object
/// \param[out] count # of rows /// \param[out] count # of rows
/// \return MSRStatus the status of MSRStatus /// \return MSRStatus the status of MSRStatus
MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op, int64_t *count); MSRStatus CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count);
/// \brief shuffle task with incremental seed /// \brief shuffle task with incremental seed
/// \return void /// \return void
...@@ -220,7 +224,7 @@ class ShardReader { ...@@ -220,7 +224,7 @@ class ShardReader {
std::vector<std::vector<json>> &column_values); std::vector<std::vector<json>> &column_values);
/// \brief initialize reader /// \brief initialize reader
MSRStatus Init(const std::string &file_path); MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
/// \brief validate column list /// \brief validate column list
MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns); MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns);
...@@ -292,8 +296,9 @@ class ShardReader { ...@@ -292,8 +296,9 @@ class ShardReader {
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories); void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories);
/// \brief get number of classes /// \brief get number of classes
int64_t GetNumClasses(const std::string &file_path, const std::string &category_field); int64_t GetNumClasses(const std::string &category_field);
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
/// \brief get exactly blob fields data by indices /// \brief get exactly blob fields data by indices
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes, std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
std::vector<uint32_t> &ordered_selected_columns_index); std::vector<uint32_t> &ordered_selected_columns_index);
......
...@@ -36,9 +36,23 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe ...@@ -36,9 +36,23 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
write_success_(true) {} write_success_(true) {}
MSRStatus ShardIndexGenerator::Build() { MSRStatus ShardIndexGenerator::Build() {
auto ret = ShardHeader::BuildSingleHeader(file_path_);
if (ret.first != SUCCESS) {
return FAILED;
}
auto json_header = ret.second;
auto ret2 = GetParentDir(file_path_);
if (SUCCESS != ret2.first) {
return FAILED;
}
std::vector<std::string> real_addresses;
for (const auto &path : json_header["shard_addresses"]) {
std::string abs_path = ret2.second + string(path);
real_addresses.emplace_back(abs_path);
}
ShardHeader header = ShardHeader(); ShardHeader header = ShardHeader();
if (header.Build(file_path_) != SUCCESS) { if (header.BuildDataset(real_addresses) == FAILED) {
MS_LOG(ERROR) << "Build shard schema failed.";
return FAILED; return FAILED;
} }
shard_header_ = header; shard_header_ = header;
......
...@@ -47,20 +47,55 @@ ShardReader::ShardReader() { ...@@ -47,20 +47,55 @@ ShardReader::ShardReader() {
block_reader_ = false; block_reader_ = false;
} }
MSRStatus ShardReader::Init(const std::string &file_path) { std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
if (!IsLegalFile(file_path)) { if (!IsLegalFile(file_path)) {
return {FAILED, {}};
}
auto ret = ShardHeader::BuildSingleHeader(file_path);
if (ret.first != SUCCESS) {
return {FAILED, {}};
}
auto header = ret.second;
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
{"version", header["version"]}, {"index_fields", header["index_fields"]},
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
return {SUCCESS, header["shard_addresses"]};
}
MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
std::string file_path = file_paths[0];
json first_meta_data = json();
auto ret = GetMeta(file_path, first_meta_data);
if (ret.first != SUCCESS) {
return FAILED; return FAILED;
} }
ShardHeader sh = ShardHeader(); if (file_paths.size() == 1 && load_dataset == true) {
if (sh.Build(file_path) == FAILED) { auto ret2 = GetParentDir(file_path);
if (SUCCESS != ret2.first) {
return FAILED;
}
std::vector<std::string> real_addresses;
for (const auto &path : ret.second) {
std::string abs_path = ret2.second + string(path);
real_addresses.emplace_back(abs_path);
}
file_paths_ = real_addresses;
} else if (file_paths.size() >= 1 && load_dataset == false) {
file_paths_ = file_paths;
} else {
MS_LOG(ERROR) << "Error in parameter file_path or load_dataset.";
return FAILED; return FAILED;
} }
shard_header_ = std::make_shared<ShardHeader>(sh);
header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize();
file_paths_ = shard_header_->GetShardAddresses();
for (const auto &file : file_paths_) { for (const auto &file : file_paths_) {
json meta_data = json();
auto ret1 = GetMeta(file, meta_data);
if (ret1.first != SUCCESS) {
return FAILED;
}
if (meta_data != first_meta_data) {
MS_LOG(ERROR) << "Mindrecord files meta information is different.";
return FAILED;
}
sqlite3 *db = nullptr; sqlite3 *db = nullptr;
// sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
...@@ -91,7 +126,13 @@ MSRStatus ShardReader::Init(const std::string &file_path) { ...@@ -91,7 +126,13 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
} }
database_paths_.push_back(db); database_paths_.push_back(db);
} }
ShardHeader sh = ShardHeader();
if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) {
return FAILED;
}
shard_header_ = std::make_shared<ShardHeader>(sh);
header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize();
num_rows_ = 0; num_rows_ = 0;
auto row_group_summary = ReadRowGroupSummary(); auto row_group_summary = ReadRowGroupSummary();
for (const auto &rg : row_group_summary) { for (const auto &rg : row_group_summary) {
...@@ -248,7 +289,6 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str ...@@ -248,7 +289,6 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
fs->close(); fs->close();
return FAILED; return FAILED;
} }
json label_json = json::from_msgpack(label_raw); json label_json = json::from_msgpack(label_raw);
json tmp; json tmp;
if (!columns.empty()) { if (!columns.empty()) {
...@@ -713,15 +753,9 @@ MSRStatus ShardReader::Finish() { ...@@ -713,15 +753,9 @@ MSRStatus ShardReader::Finish() {
return SUCCESS; return SUCCESS;
} }
int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) { int64_t ShardReader::GetNumClasses(const std::string &category_field) {
ShardHeader sh = ShardHeader(); auto shard_count = file_paths_.size();
if (sh.Build(file_path) == FAILED) { auto index_fields = shard_header_->GetFields();
return -1;
}
auto header = std::make_shared<ShardHeader>(sh);
auto file_paths = header->GetShardAddresses();
auto shard_count = file_paths.size();
auto index_fields = header->GetFields();
std::map<std::string, int64_t> map_schema_id_fields; std::map<std::string, int64_t> map_schema_id_fields;
for (auto &field : index_fields) { for (auto &field : index_fields) {
...@@ -742,7 +776,7 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri ...@@ -742,7 +776,7 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
std::set<std::string> categories; std::set<std::string> categories;
for (int x = 0; x < shard_count; x++) { for (int x = 0; x < shard_count; x++) {
sqlite3 *db = nullptr; sqlite3 *db = nullptr;
int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
if (SQLITE_OK != rc) { if (SQLITE_OK != rc) {
MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db);
return -1; return -1;
...@@ -756,16 +790,16 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri ...@@ -756,16 +790,16 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
return categories.size(); return categories.size();
} }
MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op, MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
int64_t *count) { const std::shared_ptr<ShardOperator> &op, int64_t *count) {
if (Init(file_path) == FAILED) { if (SUCCESS != Init(file_paths, load_dataset)) {
return FAILED; return FAILED;
} }
int64_t num_samples = num_rows_; int64_t num_samples = num_rows_;
if (std::dynamic_pointer_cast<ShardCategory>(op)) { if (std::dynamic_pointer_cast<ShardCategory>(op)) {
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
std::string category_field = category_op->GetCategoryField(); std::string category_field = category_op->GetCategoryField();
auto num_classes = GetNumClasses(file_path, category_field); auto num_classes = GetNumClasses(category_field);
num_samples = category_op->GetNumSamples(num_rows_, num_classes); num_samples = category_op->GetNumSamples(num_rows_, num_classes);
} else if (std::dynamic_pointer_cast<ShardSample>(op)) { } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
num_samples = op->GetNumSamples(num_rows_, 0); num_samples = op->GetNumSamples(num_rows_, 0);
...@@ -779,12 +813,13 @@ MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::s ...@@ -779,12 +813,13 @@ MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::s
return SUCCESS; return SUCCESS;
} }
MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer, MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
const std::vector<std::string> &selected_columns, const std::vector<std::string> &selected_columns,
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader) { const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader) {
// Open file and set header by ShardReader // Open file and set header by ShardReader
if (Init(file_path) == FAILED) { auto ret = Init(file_paths, load_dataset);
return FAILED; if (SUCCESS != ret) {
return ret;
} }
auto thread_limit = GetMaxThreadNum(); auto thread_limit = GetMaxThreadNum();
if (n_consumer > thread_limit) { if (n_consumer > thread_limit) {
...@@ -837,11 +872,11 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer, ...@@ -837,11 +872,11 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
return SUCCESS; return SUCCESS;
} }
MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consumer, MSRStatus ShardReader::OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
const std::vector<std::string> &selected_columns, const std::vector<std::string> &selected_columns,
const std::vector<std::shared_ptr<ShardOperator>> &operators) { const std::vector<std::shared_ptr<ShardOperator>> &operators) {
// Open file and set header by ShardReader // Open file and set header by ShardReader
if (Init(file_path) == FAILED) { if (SUCCESS != Init(file_paths, load_dataset)) {
return FAILED; return FAILED;
} }
// should remove blob field from selected_columns when call from python // should remove blob field from selected_columns when call from python
......
...@@ -174,12 +174,25 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { ...@@ -174,12 +174,25 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if (!IsLegalFile(path)) { if (!IsLegalFile(path)) {
return FAILED; return FAILED;
} }
ShardHeader sh = ShardHeader(); auto ret1 = ShardHeader::BuildSingleHeader(path);
if (sh.Build(path) == FAILED) { if (ret1.first != SUCCESS) {
return FAILED; return FAILED;
} }
shard_header_ = std::make_shared<ShardHeader>(sh); auto json_header = ret1.second;
auto paths = shard_header_->GetShardAddresses(); auto ret2 = GetParentDir(path);
if (SUCCESS != ret2.first) {
return FAILED;
}
std::vector<std::string> real_addresses;
for (const auto &path : json_header["shard_addresses"]) {
std::string abs_path = ret2.second + string(path);
real_addresses.emplace_back(abs_path);
}
ShardHeader header = ShardHeader();
if (header.BuildDataset(real_addresses) == FAILED) {
return FAILED;
}
shard_header_ = std::make_shared<ShardHeader>(header);
MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize()); MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize());
if (ret == FAILED) { if (ret == FAILED) {
return FAILED; return FAILED;
...@@ -188,7 +201,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { ...@@ -188,7 +201,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if (ret == FAILED) { if (ret == FAILED) {
return FAILED; return FAILED;
} }
ret = Open(paths, true); ret = Open(json_header["shard_addresses"], true);
if (ret == FAILED) { if (ret == FAILED) {
MS_LOG(ERROR) << "Open file failed"; MS_LOG(ERROR) << "Open file failed";
return FAILED; return FAILED;
......
...@@ -35,8 +35,9 @@ namespace mindrecord { ...@@ -35,8 +35,9 @@ namespace mindrecord {
std::atomic<bool> thread_status(false); std::atomic<bool> thread_status(false);
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); } ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); }
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) { MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
shard_count_ = headers.size(); shard_count_ = headers.size();
int shard_index = 0;
bool first = true; bool first = true;
for (const auto &header : headers) { for (const auto &header : headers) {
if (first) { if (first) {
...@@ -54,7 +55,8 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) { ...@@ -54,7 +55,8 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
header_size_ = header["header_size"].get<uint64_t>(); header_size_ = header["header_size"].get<uint64_t>();
page_size_ = header["page_size"].get<uint64_t>(); page_size_ = header["page_size"].get<uint64_t>();
} }
ParsePage(header["page"]); ParsePage(header["page"], shard_index, load_dataset);
shard_index++;
} }
return SUCCESS; return SUCCESS;
} }
...@@ -136,40 +138,39 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path) ...@@ -136,40 +138,39 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path)
return {SUCCESS, json_header}; return {SUCCESS, json_header};
} }
MSRStatus ShardHeader::Build(const std::string &file_path) { std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &file_path) {
auto ret = ValidateHeader(file_path); auto ret = ValidateHeader(file_path);
if (SUCCESS != ret.first) { if (SUCCESS != ret.first) {
return FAILED; return {FAILED, json()};
}
json main_header = ret.second;
json addresses = main_header["shard_addresses"];
vector<string> real_addresses;
auto ret1 = GetParentDir(file_path);
if (SUCCESS != ret1.first) {
return FAILED;
} }
std::string parent_dir = ret1.second; json raw_header = ret.second;
json header = {{"shard_addresses", raw_header["shard_addresses"]},
{"header_size", raw_header["header_size"]},
{"page_size", raw_header["page_size"]},
{"index_fields", raw_header["index_fields"]},
{"blob_fields", raw_header["schema"][0]["blob_fields"]},
{"schema", raw_header["schema"][0]["schema"]},
{"version", raw_header["version"]}};
return {SUCCESS, header};
}
for (const auto &addr : addresses) { MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
std::string absolute_path = parent_dir + string(addr);
real_addresses.emplace_back(absolute_path);
}
uint32_t thread_num = std::thread::hardware_concurrency(); uint32_t thread_num = std::thread::hardware_concurrency();
if (thread_num == 0) thread_num = kThreadNumber; if (thread_num == 0) thread_num = kThreadNumber;
uint32_t work_thread_num = 0; uint32_t work_thread_num = 0;
uint32_t addr_count = real_addresses.size(); uint32_t shard_count = file_paths.size();
int group_num = ceil(addr_count * 1.0 / thread_num); int group_num = ceil(shard_count * 1.0 / thread_num);
std::vector<std::thread> thread_set(thread_num); std::vector<std::thread> thread_set(thread_num);
std::vector<json> headers(addr_count); std::vector<json> headers(shard_count);
for (uint32_t x = 0; x < thread_num; ++x) { for (uint32_t x = 0; x < thread_num; ++x) {
int start_num = x * group_num; int start_num = x * group_num;
int end_num = ((x + 1) * group_num > addr_count) ? addr_count : (x + 1) * group_num; int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num;
if (start_num >= end_num) { if (start_num >= end_num) {
continue; continue;
} }
thread_set[x] = thread_set[x] =
std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), real_addresses); std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths);
work_thread_num++; work_thread_num++;
} }
...@@ -180,7 +181,7 @@ MSRStatus ShardHeader::Build(const std::string &file_path) { ...@@ -180,7 +181,7 @@ MSRStatus ShardHeader::Build(const std::string &file_path) {
thread_status = false; thread_status = false;
return FAILED; return FAILED;
} }
if (SUCCESS != InitializeHeader(headers)) { if (SUCCESS != InitializeHeader(headers, load_dataset)) {
return FAILED; return FAILED;
} }
return SUCCESS; return SUCCESS;
...@@ -247,7 +248,8 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { ...@@ -247,7 +248,8 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
return SUCCESS; return SUCCESS;
} }
void ShardHeader::ParsePage(const json &pages) { void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
// set shard_index when load_dataset is false
if (pages_.empty() && shard_count_ <= kMaxShardCount) { if (pages_.empty() && shard_count_ <= kMaxShardCount) {
pages_.resize(shard_count_); pages_.resize(shard_count_);
} }
...@@ -267,7 +269,11 @@ void ShardHeader::ParsePage(const json &pages) { ...@@ -267,7 +269,11 @@ void ShardHeader::ParsePage(const json &pages) {
std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id, std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id,
end_row_id, row_group_ids, page_size); end_row_id, row_group_ids, page_size);
pages_[shard_id].push_back(std::move(parsed_page)); if (load_dataset == true) {
pages_[shard_id].push_back(std::move(parsed_page));
} else {
pages_[shard_index].push_back(std::move(parsed_page));
}
} }
} }
...@@ -709,7 +715,7 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { ...@@ -709,7 +715,7 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
std::string line; std::string line;
while (std::getline(page_in_handle, line)) { while (std::getline(page_in_handle, line)) {
ParsePage(json::parse(line)); ParsePage(json::parse(line), -1, true);
} }
page_in_handle.close(); page_in_handle.close();
......
...@@ -2189,7 +2189,7 @@ class MindDataset(SourceDataset): ...@@ -2189,7 +2189,7 @@ class MindDataset(SourceDataset):
A source dataset that reads from shard files and database. A source dataset that reads from shard files and database.
Args: Args:
dataset_file (str): one of file names in dataset. dataset_file (str, list[str]): One of file names or file list in dataset.
columns_list (list[str], optional): List of columns to be read (default=None). columns_list (list[str], optional): List of columns to be read (default=None).
num_parallel_workers (int, optional): The number of readers (default=None). num_parallel_workers (int, optional): The number of readers (default=None).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset shuffle (bool, optional): Whether or not to perform shuffle on the dataset
...@@ -2214,6 +2214,10 @@ class MindDataset(SourceDataset): ...@@ -2214,6 +2214,10 @@ class MindDataset(SourceDataset):
shuffle=None, num_shards=None, shard_id=None, shuffle=None, num_shards=None, shard_id=None,
block_reader=False, sampler=None): block_reader=False, sampler=None):
super().__init__(num_parallel_workers) super().__init__(num_parallel_workers)
if isinstance(dataset_file, list):
self.load_dataset = False
else:
self.load_dataset = True
self.dataset_file = dataset_file self.dataset_file = dataset_file
self.columns_list = columns_list self.columns_list = columns_list
self.global_shuffle = shuffle self.global_shuffle = shuffle
...@@ -2256,6 +2260,7 @@ class MindDataset(SourceDataset): ...@@ -2256,6 +2260,7 @@ class MindDataset(SourceDataset):
def get_args(self): def get_args(self):
args = super().get_args() args = super().get_args()
args["dataset_file"] = self.dataset_file args["dataset_file"] = self.dataset_file
args["load_dataset"] = self.load_dataset
args["columns_list"] = self.columns_list args["columns_list"] = self.columns_list
args["global_shuffle"] = self.global_shuffle args["global_shuffle"] = self.global_shuffle
args["partitions"] = self.partitions args["partitions"] = self.partitions
...@@ -2272,8 +2277,11 @@ class MindDataset(SourceDataset): ...@@ -2272,8 +2277,11 @@ class MindDataset(SourceDataset):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
if self.load_dataset:
num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler) dataset_file = [self.dataset_file]
else:
dataset_file = self.dataset_file
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler)
if self.partitions is not None and self.partitions[0] > 0: if self.partitions is not None and self.partitions[0] > 0:
if num_rows % self.partitions[0] == 0: if num_rows % self.partitions[0] == 0:
num_rows = num_rows // self.partitions[0] num_rows = num_rows // self.partitions[0]
......
...@@ -529,8 +529,11 @@ def check_minddataset(method): ...@@ -529,8 +529,11 @@ def check_minddataset(method):
dataset_file = param_dict.get('dataset_file') dataset_file = param_dict.get('dataset_file')
if dataset_file is None: if dataset_file is None:
raise ValueError("dataset_file is not provided.") raise ValueError("dataset_file is not provided.")
check_dataset_file(dataset_file) if isinstance(dataset_file, list):
for f in dataset_file:
check_dataset_file(f)
else:
check_dataset_file(dataset_file)
check_param_type(nreq_param_int, param_dict, int) check_param_type(nreq_param_int, param_dict, int)
check_param_type(nreq_param_list, param_dict, list) check_param_type(nreq_param_list, param_dict, list)
......
...@@ -28,7 +28,7 @@ class FileReader: ...@@ -28,7 +28,7 @@ class FileReader:
Class to read MindRecord File series. Class to read MindRecord File series.
Args: Args:
file_name (str): File name of MindRecord File. file_name (str, list[str]): One of MindRecord File or file list.
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4). num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
It should not be smaller than 1 or larger than the number of CPU. It should not be smaller than 1 or larger than the number of CPU.
columns (list[str], optional): List of fields which correspond data would be read (default=None). columns (list[str], optional): List of fields which correspond data would be read (default=None).
...@@ -38,8 +38,11 @@ class FileReader: ...@@ -38,8 +38,11 @@ class FileReader:
ParamValueError: If file_name, num_consumer or columns is invalid. ParamValueError: If file_name, num_consumer or columns is invalid.
""" """
def __init__(self, file_name, num_consumer=4, columns=None, operator=None): def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
check_filename(file_name) if isinstance(file_name, list):
self._file_name = file_name for f in file_name:
check_filename(f)
else:
check_filename(file_name)
if num_consumer is not None: if num_consumer is not None:
if isinstance(num_consumer, int): if isinstance(num_consumer, int):
......
...@@ -28,7 +28,7 @@ class MindPage: ...@@ -28,7 +28,7 @@ class MindPage:
Class to read MindRecord File series in pagination. Class to read MindRecord File series in pagination.
Args: Args:
file_name (str): File name of MindRecord File. file_name (str): One of MindRecord File or file list.
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4). num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
It should not be smaller than 1 or larger than the number of CPU. It should not be smaller than 1 or larger than the number of CPU.
...@@ -37,8 +37,11 @@ class MindPage: ...@@ -37,8 +37,11 @@ class MindPage:
MRMInitSegmentError: If failed to initialize ShardSegment. MRMInitSegmentError: If failed to initialize ShardSegment.
""" """
def __init__(self, file_name, num_consumer=4): def __init__(self, file_name, num_consumer=4):
check_filename(file_name) if isinstance(file_name, list):
self._file_name = file_name for f in file_name:
check_filename(f)
else:
check_filename(file_name)
if num_consumer is not None: if num_consumer is not None:
if isinstance(num_consumer, int): if isinstance(num_consumer, int):
......
...@@ -35,7 +35,7 @@ class ShardReader: ...@@ -35,7 +35,7 @@ class ShardReader:
Open file and prepare to read MindRecord File. Open file and prepare to read MindRecord File.
Args: Args:
file_name (str): File name of MindRecord File. file_name (str, list[str]): File names of MindRecord File.
num_consumer (int): Number of worker threads which load data in parallel. Default: 4. num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
columns (list[str]): List of fields which correspond data would be read. columns (list[str]): List of fields which correspond data would be read.
operator(int): Reserved parameter for operators. Default: None. operator(int): Reserved parameter for operators. Default: None.
...@@ -48,7 +48,12 @@ class ShardReader: ...@@ -48,7 +48,12 @@ class ShardReader:
""" """
columns = columns if columns else [] columns = columns if columns else []
operator = operator if operator else [] operator = operator if operator else []
ret = self._reader.open(file_name, num_consumer, columns, operator) if isinstance(file_name, list):
load_dataset = False
else:
load_dataset = True
file_name = [file_name]
ret = self._reader.open(file_name, load_dataset, num_consumer, columns, operator)
if ret != ms.MSRStatus.SUCCESS: if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to open {}.".format(file_name)) logger.error("Failed to open {}.".format(file_name))
raise MRMOpenError raise MRMOpenError
......
...@@ -40,7 +40,7 @@ class ShardSegment: ...@@ -40,7 +40,7 @@ class ShardSegment:
Initialize the ShardSegment. Initialize the ShardSegment.
Args: Args:
file_name (str): File name of MindRecord File. file_name (str, list[str]): File names of MindRecord File.
num_consumer (int): Number of worker threads which load data in parallel. Default: 4. num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
columns (list[str]): List of fields which correspond data would be read. columns (list[str]): List of fields which correspond data would be read.
operator(int): Reserved parameter for operators. Default: None. operator(int): Reserved parameter for operators. Default: None.
...@@ -53,7 +53,12 @@ class ShardSegment: ...@@ -53,7 +53,12 @@ class ShardSegment:
""" """
self._columns = columns if columns else [] self._columns = columns if columns else []
operator = operator if operator else [] operator = operator if operator else []
ret = self._segment.open(file_name, num_consumer, self._columns, operator) if isinstance(file_name, list):
load_dataset = False
else:
load_dataset = True
file_name = [file_name]
ret = self._segment.open(file_name, load_dataset, num_consumer, self._columns, operator)
if ret != SUCCESS: if ret != SUCCESS:
logger.error("Failed to open {}.".format(file_name)) logger.error("Failed to open {}.".format(file_name))
raise MRMOpenError raise MRMOpenError
......
...@@ -62,7 +62,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBasic) { ...@@ -62,7 +62,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBasic) {
std::shared_ptr<MindRecordOp> my_mindrecord_op; std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder; MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3) .SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4) .SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list); .SetColumnsToLoad(column_list);
...@@ -132,7 +133,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordSample) { ...@@ -132,7 +133,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordSample) {
std::shared_ptr<MindRecordOp> my_mindrecord_op; std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder; MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3) .SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4) .SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list) .SetColumnsToLoad(column_list)
...@@ -203,7 +205,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordShuffle) { ...@@ -203,7 +205,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordShuffle) {
std::shared_ptr<MindRecordOp> my_mindrecord_op; std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder; MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3) .SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4) .SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list) .SetColumnsToLoad(column_list)
...@@ -277,7 +280,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordCategory) { ...@@ -277,7 +280,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordCategory) {
std::shared_ptr<MindRecordOp> my_mindrecord_op; std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder; MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3) .SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4) .SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list) .SetColumnsToLoad(column_list)
...@@ -345,7 +349,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) { ...@@ -345,7 +349,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) {
std::shared_ptr<MindRecordOp> my_mindrecord_op; std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder; MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3) .SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4) .SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list); .SetColumnsToLoad(column_list);
...@@ -426,7 +431,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) { ...@@ -426,7 +431,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) {
std::shared_ptr<MindRecordOp> my_mindrecord_op; std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder; MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3) .SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4) .SetNumMindRecordWorkers(4)
.SetBlockReader() .SetBlockReader()
...@@ -507,7 +513,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordInvalidColumnList) { ...@@ -507,7 +513,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordInvalidColumnList) {
std::shared_ptr<MindRecordOp> my_mindrecord_op; std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder; MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3) .SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4) .SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list); .SetColumnsToLoad(column_list);
......
...@@ -63,7 +63,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) { ...@@ -63,7 +63,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
std::vector<std::shared_ptr<ShardOperator>> ops; std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardSample>(kSampleCount)); ops.push_back(std::make_shared<ShardSample>(kSampleCount));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -89,7 +89,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) { ...@@ -89,7 +89,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
ops.push_back(std::make_shared<ShardSample>(kNum, kDen)); ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -115,7 +115,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) { ...@@ -115,7 +115,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
ops.push_back(std::make_shared<ShardSample>(kNum, kDen)); ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -144,7 +144,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { ...@@ -144,7 +144,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
ASSERT_TRUE(partitions.second == 2); ASSERT_TRUE(partitions.second == 2);
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -168,7 +168,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) { ...@@ -168,7 +168,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
ops.push_back(std::make_shared<ShardPkSample>("label", 2)); ops.push_back(std::make_shared<ShardPkSample>("label", 2));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name},true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -193,7 +193,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { ...@@ -193,7 +193,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0)); ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name},true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -223,7 +223,7 @@ TEST_F(TestShardOperator, TestShardCategory) { ...@@ -223,7 +223,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
ops.push_back(std::make_shared<ShardCategory>(categories)); ops.push_back(std::make_shared<ShardCategory>(categories));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -254,7 +254,7 @@ TEST_F(TestShardOperator, TestShardShuffle) { ...@@ -254,7 +254,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
ops.push_back(std::make_shared<ShardShuffle>(1)); ops.push_back(std::make_shared<ShardShuffle>(1));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 16, column_list, ops); dataset.Open({file_name}, true, 16, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -279,7 +279,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) { ...@@ -279,7 +279,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
ops.push_back(std::make_shared<ShardShuffle>(1)); ops.push_back(std::make_shared<ShardShuffle>(1));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -306,7 +306,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) { ...@@ -306,7 +306,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
ops.push_back(std::make_shared<ShardSample>(kSampleSize)); ops.push_back(std::make_shared<ShardSample>(kSampleSize));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -333,7 +333,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) { ...@@ -333,7 +333,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
ops.push_back(std::make_shared<ShardSample>(35)); ops.push_back(std::make_shared<ShardSample>(35));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -357,11 +357,11 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) { ...@@ -357,11 +357,11 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
ops.push_back(std::make_shared<ShardShuffle>(1)); ops.push_back(std::make_shared<ShardShuffle>(1));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
ShardReader compare_dataset; ShardReader compare_dataset;
compare_dataset.Open(file_name, 4, column_list); compare_dataset.Open({file_name},true, 4, column_list);
compare_dataset.Launch(); compare_dataset.Launch();
int i = 0; int i = 0;
...@@ -396,7 +396,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) { ...@@ -396,7 +396,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
ops.push_back(std::make_shared<ShardShuffle>(21)); ops.push_back(std::make_shared<ShardShuffle>(21));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -430,7 +430,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) { ...@@ -430,7 +430,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
ops.push_back(std::make_shared<ShardCategory>(categories)); ops.push_back(std::make_shared<ShardCategory>(categories));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -464,7 +464,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) { ...@@ -464,7 +464,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
ops.push_back(std::make_shared<ShardCategory>(categories)); ops.push_back(std::make_shared<ShardCategory>(categories));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name},true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
...@@ -502,7 +502,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) { ...@@ -502,7 +502,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
ops.push_back(std::make_shared<ShardShuffle>(100)); ops.push_back(std::make_shared<ShardShuffle>(100));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
int i = 0; int i = 0;
......
...@@ -55,7 +55,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) { ...@@ -55,7 +55,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
auto column_list = std::vector<std::string>{"file_name"}; auto column_list = std::vector<std::string>{"file_name"};
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list); dataset.Open({file_name}, true, 4, column_list);
dataset.Launch(); dataset.Launch();
while (true) { while (true) {
...@@ -78,7 +78,7 @@ TEST_F(TestShardReader, TestShardReaderSample) { ...@@ -78,7 +78,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
std::vector<std::shared_ptr<ShardOperator>> ops; std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardSample>(17)); ops.push_back(std::make_shared<ShardSample>(17));
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops); dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch(); dataset.Launch();
while (true) { while (true) {
...@@ -103,7 +103,7 @@ TEST_F(TestShardReader, TestShardReaderBlock) { ...@@ -103,7 +103,7 @@ TEST_F(TestShardReader, TestShardReaderBlock) {
ops.push_back(std::make_shared<ShardSample>(3)); ops.push_back(std::make_shared<ShardSample>(3));
ShardReader dataset; ShardReader dataset;
const bool kBlockReader = true; const bool kBlockReader = true;
dataset.Open(file_name, 4, column_list, ops, kBlockReader); dataset.Open({file_name}, true, 4, column_list, ops, kBlockReader);
dataset.Launch(); dataset.Launch();
while (true) { while (true) {
...@@ -123,7 +123,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) { ...@@ -123,7 +123,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
MS_LOG(INFO) << FormatInfo("Test read imageNet"); MS_LOG(INFO) << FormatInfo("Test read imageNet");
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name); dataset.Open({file_name}, true);
dataset.Launch(); dataset.Launch();
while (true) { while (true) {
...@@ -143,7 +143,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { ...@@ -143,7 +143,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label"}; auto column_list = std::vector<std::string>{"label"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open(file_name, 4, column_list); MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
dataset.Launch(); dataset.Launch();
...@@ -164,7 +164,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { ...@@ -164,7 +164,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"file_namex"}; auto column_list = std::vector<std::string>{"file_namex"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open(file_name, 4, column_list); MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST); ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST);
} }
...@@ -172,7 +172,7 @@ TEST_F(TestShardReader, TestShardVersion) { ...@@ -172,7 +172,7 @@ TEST_F(TestShardReader, TestShardVersion) {
MS_LOG(INFO) << FormatInfo("Test shard version"); MS_LOG(INFO) << FormatInfo("Test shard version");
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open(file_name, 4); MSRStatus ret = dataset.Open({file_name}, true, 4);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
dataset.Launch(); dataset.Launch();
...@@ -195,7 +195,7 @@ TEST_F(TestShardReader, TestShardReaderDir) { ...@@ -195,7 +195,7 @@ TEST_F(TestShardReader, TestShardReaderDir) {
auto column_list = std::vector<std::string>{"file_name"}; auto column_list = std::vector<std::string>{"file_name"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open(file_name, 4, column_list); MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, FAILED); ASSERT_EQ(ret, FAILED);
} }
...@@ -205,7 +205,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) { ...@@ -205,7 +205,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
auto column_list = std::vector<std::string>{"file_name"}; auto column_list = std::vector<std::string>{"file_name"};
ShardReader dataset; ShardReader dataset;
dataset.Open(file_name, -481565535, column_list); dataset.Open({file_name}, true, -481565535, column_list);
dataset.Launch(); dataset.Launch();
while (true) { while (true) {
......
...@@ -59,7 +59,7 @@ TEST_F(TestShardSegment, TestShardSegment) { ...@@ -59,7 +59,7 @@ TEST_F(TestShardSegment, TestShardSegment) {
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
ShardSegment dataset; ShardSegment dataset;
dataset.Open(file_name, 4); dataset.Open({file_name}, true, 4);
auto x = dataset.GetCategoryFields(); auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) { for (const auto &fields : x.second) {
...@@ -97,7 +97,7 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { ...@@ -97,7 +97,7 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) {
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
ShardSegment dataset; ShardSegment dataset;
dataset.Open(file_name, 4); dataset.Open({file_name}, true, 4);
auto x = dataset.GetCategoryFields(); auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) { for (const auto &fields : x.second) {
...@@ -121,7 +121,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { ...@@ -121,7 +121,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) {
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
ShardSegment dataset; ShardSegment dataset;
dataset.Open(file_name, 4); dataset.Open({file_name}, true, 4);
auto x = dataset.GetCategoryFields(); auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) { for (const auto &fields : x.second) {
...@@ -143,7 +143,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { ...@@ -143,7 +143,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) {
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
ShardSegment dataset; ShardSegment dataset;
dataset.Open(file_name, 4); dataset.Open({file_name}, true, 4);
auto x = dataset.GetCategoryFields(); auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) { for (const auto &fields : x.second) {
...@@ -165,7 +165,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { ...@@ -165,7 +165,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) {
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
ShardSegment dataset; ShardSegment dataset;
dataset.Open(file_name, 4); dataset.Open({file_name}, true, 4);
auto x = dataset.GetCategoryFields(); auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) { for (const auto &fields : x.second) {
......
...@@ -60,7 +60,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) { ...@@ -60,7 +60,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
std::string filename = "./OneSample.shard01"; std::string filename = "./OneSample.shard01";
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open(filename, 4); MSRStatus ret = dataset.Open({filename}, true, 4);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
dataset.Launch(); dataset.Launch();
...@@ -756,7 +756,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { ...@@ -756,7 +756,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
filename = "./imagenet.shard01"; filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "file_name", "data"}; auto column_list = std::vector<std::string>{"label", "file_name", "data"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open(filename, 4, column_list); MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
dataset.Launch(); dataset.Launch();
...@@ -842,7 +842,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) { ...@@ -842,7 +842,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
filename = "./imagenet.shard01"; filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "file_name"}; auto column_list = std::vector<std::string>{"label", "file_name"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open(filename, 4, column_list); MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
dataset.Launch(); dataset.Launch();
...@@ -936,7 +936,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { ...@@ -936,7 +936,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
filename = "./imagenet.shard01"; filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "data"}; auto column_list = std::vector<std::string>{"label", "data"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open(filename, 4, column_list); MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
dataset.Launch(); dataset.Launch();
...@@ -1043,7 +1043,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) { ...@@ -1043,7 +1043,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
filename = "./TenSampleFortyShard.shard01"; filename = "./TenSampleFortyShard.shard01";
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open(filename, 4); MSRStatus ret = dataset.Open({filename}, true, 4);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
dataset.Launch(); dataset.Launch();
......
...@@ -32,6 +32,8 @@ from mindspore.mindrecord import FileWriter ...@@ -32,6 +32,8 @@ from mindspore.mindrecord import FileWriter
FILES_NUM = 4 FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord" CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord"
CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData" CV_DIR_NAME = "../data/mindrecord/testImageNetData"
NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord" NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos" NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
...@@ -111,7 +113,6 @@ def test_cv_minddataset_writer_tutorial(): ...@@ -111,7 +113,6 @@ def test_cv_minddataset_writer_tutorial():
os.remove("{}".format(x)) os.remove("{}".format(x))
os.remove("{}.db".format(x)) os.remove("{}.db".format(x))
def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
"""tutorial for cv minddataset.""" """tutorial for cv minddataset."""
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
...@@ -247,6 +248,126 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem ...@@ -247,6 +248,126 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem
assert num_iter == 20 assert num_iter == 20
def test_cv_minddataset_reader_file_list(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)], columns_list, num_readers)
assert data_set.get_dataset_size() == 10
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 10
def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
data_set = ds.MindDataset([CV_FILE_NAME + "0"], columns_list, num_readers)
assert data_set.get_dataset_size() < 10
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter < 10
def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
if os.path.exists(CV1_FILE_NAME):
os.remove(CV1_FILE_NAME)
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
os.remove("{}.db".format(CV1_FILE_NAME))
if os.path.exists(CV2_FILE_NAME):
os.remove(CV2_FILE_NAME)
if os.path.exists("{}.db".format(CV2_FILE_NAME)):
os.remove("{}.db".format(CV2_FILE_NAME))
writer = FileWriter(CV1_FILE_NAME, 1)
data = get_data(CV_DIR_NAME)
cv_schema_json = {"id": {"type": "int32"},
"file_name": {"type": "string"},
"label": {"type": "int32"},
"data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "CV1_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
writer = FileWriter(CV2_FILE_NAME, 1)
data = get_data(CV_DIR_NAME)
cv_schema_json = {"id": {"type": "int32"},
"file_name": {"type": "string"},
"label": {"type": "int32"},
"data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "CV2_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
columns_list = ["data", "file_name", "label"]
num_readers = 4
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME], columns_list, num_readers)
assert data_set.get_dataset_size() == 30
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 30
if os.path.exists(CV1_FILE_NAME):
os.remove(CV1_FILE_NAME)
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
os.remove("{}.db".format(CV1_FILE_NAME))
if os.path.exists(CV2_FILE_NAME):
os.remove(CV2_FILE_NAME)
if os.path.exists("{}.db".format(CV2_FILE_NAME)):
os.remove("{}.db".format(CV2_FILE_NAME))
def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None
os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
data = get_data(CV_DIR_NAME)
cv_schema_json = {"id": {"type": "int32"},
"file_name": {"type": "string"},
"label": {"type": "int32"},
"data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "CV1_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
columns_list = ["data", "file_name", "label"]
num_readers = 4
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(2)] + [CV1_FILE_NAME + str(x) for x in range(2, 4)], columns_list, num_readers)
assert data_set.get_dataset_size() < 20
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter < 20
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
"""tutorial for cv minderdataset.""" """tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
......
...@@ -22,6 +22,7 @@ import mindspore.dataset as ds ...@@ -22,6 +22,7 @@ import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter from mindspore.mindrecord import FileWriter
CV_FILE_NAME = "./imagenet.mindrecord" CV_FILE_NAME = "./imagenet.mindrecord"
CV1_FILE_NAME = "./imagenet1.mindrecord"
def create_cv_mindrecord(files_num): def create_cv_mindrecord(files_num):
...@@ -37,6 +38,31 @@ def create_cv_mindrecord(files_num): ...@@ -37,6 +38,31 @@ def create_cv_mindrecord(files_num):
writer.commit() writer.commit()
def create_diff_schema_cv_mindrecord(files_num):
"""tutorial for cv dataset writer."""
os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
writer = FileWriter(CV1_FILE_NAME, files_num)
cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name_1", "label"])
writer.write_raw_data(data)
writer.commit()
def create_diff_page_size_cv_mindrecord(files_num):
"""tutorial for cv dataset writer."""
os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
writer = FileWriter(CV1_FILE_NAME, files_num)
writer.set_page_size(1<< 26) #64MB
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
def test_cv_lack_json(): def test_cv_lack_json():
"""tutorial for cv minderdataset.""" """tutorial for cv minderdataset."""
create_cv_mindrecord(1) create_cv_mindrecord(1)
...@@ -111,3 +137,34 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle(): ...@@ -111,3 +137,34 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
def test_cv_minddataset_reader_different_schema():
create_cv_mindrecord(1)
create_diff_schema_cv_mindrecord(1)
columns_list = ["data", "label"]
num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
num_readers)
num_iter = 0
for item in data_set.create_dict_iterator():
num_iter += 1
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
os.remove(CV1_FILE_NAME)
os.remove("{}.db".format(CV1_FILE_NAME))
def test_cv_minddataset_reader_different_page_size():
create_cv_mindrecord(1)
create_diff_page_size_cv_mindrecord(1)
columns_list = ["data", "label"]
num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
num_readers)
num_iter = 0
for item in data_set.create_dict_iterator():
num_iter += 1
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
os.remove(CV1_FILE_NAME)
os.remove("{}.db".format(CV1_FILE_NAME))
...@@ -202,6 +202,16 @@ def test_cv_file_reader_tutorial(): ...@@ -202,6 +202,16 @@ def test_cv_file_reader_tutorial():
assert count == 10 assert count == 10
reader.close() reader.close()
def test_cv_file_reader_file_list():
"""tutorial for cv file partial reader."""
reader = FileReader([CV_FILE_NAME + str(x) for x in range(FILES_NUM)])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 3
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 10
def test_cv_file_reader_partial_tutorial(): def test_cv_file_reader_partial_tutorial():
"""tutorial for cv file partial reader.""" """tutorial for cv file partial reader."""
reader = FileReader(CV_FILE_NAME + "0") reader = FileReader(CV_FILE_NAME + "0")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册