提交 aa3f89e7 编写于 作者: L liyong

mindrecord support read file list

上级 a2d5ad5a
......@@ -408,8 +408,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
}
std::shared_ptr<MindRecordOp::Builder> builder = std::make_shared<MindRecordOp::Builder>();
(void)builder->SetDatasetFile(ToString(args["dataset_file"]));
bool load_dataset = ToBool(args["load_dataset"]);
if (load_dataset == true) {
(void)builder->SetDatasetFile({ToString(args["dataset_file"])});
} else {
(void)builder->SetDatasetFile(ToStringVector(args["dataset_file"]));
}
(void)builder->SetLoadDataset(load_dataset);
std::vector<std::string> in_col_names;
if (!args["columns_list"].is_none()) {
in_col_names = ToStringVector(args["columns_list"]);
......
......@@ -151,16 +151,17 @@ void bindDatasetOps(py::module *m) {
});
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
.def_static("get_num_rows", [](const std::string &path, const py::object &sampler) {
int64_t count = 0;
std::shared_ptr<mindrecord::ShardOperator> op;
if (py::hasattr(sampler, "_create_for_minddataset")) {
auto create = sampler.attr("_create_for_minddataset");
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
}
THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count));
return count;
});
.def_static("get_num_rows",
[](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler) {
int64_t count = 0;
std::shared_ptr<mindrecord::ShardOperator> op;
if (py::hasattr(sampler, "_create_for_minddataset")) {
auto create = sampler.attr("_create_for_minddataset");
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
}
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count));
return count;
});
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
.def_static("get_num_rows_and_classes",
......
......@@ -40,7 +40,7 @@ using mindrecord::ShardOperator;
using mindrecord::ShardReader;
// Builder constructor. Creates the builder object.
MindRecordOp::Builder::Builder() : build_dataset_file_("") {
MindRecordOp::Builder::Builder() : build_dataset_file_({}) {
// Some arguments to the MindRecordOp constructor have a default argument that is taken
// from the client config.
// The user may choose to change these values for the construction of the StorageOp by
......@@ -63,9 +63,9 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
"Building a MindRecordOp that has not provided a file.");
}
new_mind_record_op = std::make_shared<MindRecordOp>(build_num_mind_record_workers_, build_rows_per_buffer_,
build_dataset_file_, build_op_connector_queue_size_,
build_columns_to_load_, build_operators_, build_block_reader_);
new_mind_record_op = std::make_shared<MindRecordOp>(
build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_,
build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_);
RETURN_IF_NOT_OK(new_mind_record_op->Init());
......@@ -76,12 +76,14 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); }
// Constructor of the MindRecordOp.
MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file,
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer,
std::vector<std::string> dataset_file, bool load_dataset, int32_t op_connector_queue_size,
const std::vector<std::string> &columns_to_load,
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader)
: ParallelOp(num_mind_record_workers, op_connector_queue_size),
rows_per_buffer_(rows_per_buffer),
dataset_file_(dataset_file),
load_dataset_(load_dataset),
columns_to_load_(columns_to_load),
operators_(operators),
num_mind_record_workers_(num_mind_record_workers),
......@@ -101,9 +103,10 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
// Private helper method to encapsulate some common construction/reset tasks
Status MindRecordOp::Init() {
shard_reader_ = std::make_unique<ShardReader>();
auto rc = shard_reader_->Open(dataset_file_, num_mind_record_workers_, columns_to_load_, operators_, block_reader_);
auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_,
block_reader_);
CHECK_FAIL_RETURN_UNEXPECTED(rc != MSRStatus::FAILED,
CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS,
"MindRecordOp init failed. Error message: " + ErrnoToMessage(rc));
data_schema_ = std::make_unique<DataSchema>();
......@@ -201,8 +204,12 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\n1 Dataset file : " << dataset_file_ << "\nNumber of rows : " << num_rows_
<< "\nRows per buffer : " << rows_per_buffer_ << "\nNumber of buffers : " << buffers_needed_
out << "\n Dataset file : ";
for (auto &file : dataset_file_) {
out << file << " ";
}
out << "\nNumber of rows : " << num_rows_ << "\nRows per buffer : " << rows_per_buffer_
<< "\nNumber of buffers : " << buffers_needed_
<< "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n";
}
}
......@@ -668,10 +675,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
return Status::OK();
}
Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
int64_t *count) {
Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count) {
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count);
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count);
if (rc == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed.");
}
......
......@@ -77,8 +77,8 @@ class MindRecordOp : public ParallelOp {
return *this;
}
Builder &SetDatasetFile(const std::string &file) {
build_dataset_file_ = file;
Builder &SetDatasetFile(const std::vector<std::string> &files) {
build_dataset_file_ = files;
return *this;
}
......@@ -97,6 +97,11 @@ class MindRecordOp : public ParallelOp {
return *this;
}
Builder &SetLoadDataset(bool load_dataset) {
build_load_dataset_ = load_dataset;
return *this;
}
Status SanityCheck() const;
static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; }
......@@ -109,7 +114,8 @@ class MindRecordOp : public ParallelOp {
int32_t builder_num_workers_;
int32_t build_rows_per_buffer_;
int32_t build_op_connector_queue_size_;
std::string build_dataset_file_;
std::vector<std::string> build_dataset_file_;
bool build_load_dataset_;
std::vector<std::string> build_columns_to_load_;
std::vector<std::shared_ptr<ShardOperator>> build_operators_;
bool build_block_reader_;
......@@ -119,12 +125,12 @@ class MindRecordOp : public ParallelOp {
// @note The builder class should be used to call it
// @param num_mind_record_workers - The number of workers for the op (run by ShardReader)
// @param rows_per_buffer - The requested number of rows per buffer
// @param dataset_file - A shard file
// @param dataset_file - dataset files
// @param op_connector_queue_size - The output connector queue size
// @param columns_to_load - The list of columns to use (column name)
// @param operators - ShardOperators for Shuffle, Category, Sample
MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file,
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector<std::string> dataset_file,
bool load_dataset, int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader);
// Destructor
......@@ -169,21 +175,22 @@ class MindRecordOp : public ParallelOp {
// Getter method
int32_t num_rows() const { return num_rows_; }
// Getter method
static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
int64_t *count);
static Status CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count);
// Getter method
int32_t rows_per_buffer() const { return rows_per_buffer_; }
// Getter method
std::string dataset_file() const { return dataset_file_; }
std::vector<std::string> dataset_file() const { return dataset_file_; }
// Getter method
std::vector<std::string> columns_to_load() const { return columns_to_load_; }
bool block_reader() const { return block_reader_; }
bool load_dataset() const { return load_dataset_; }
Status Init();
Status SetColumnsBlob();
......@@ -246,7 +253,8 @@ class MindRecordOp : public ParallelOp {
Status FetchBlockBuffer(const int32_t &buffer_id);
int32_t rows_per_buffer_; // The number of requested rows per buffer.
std::string dataset_file_; // A dataset file
std::vector<std::string> dataset_file_; // dataset files
bool load_dataset_; // load dataset from single file or not
std::vector<std::string> columns_to_load_; // Columns to load from dataset
std::vector<std::shared_ptr<ShardOperator>> operators_; // ShardOperators to use
int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader
......
......@@ -170,6 +170,9 @@ std::string ErrnoToMessage(MSRStatus status) {
case IO_FAILED:
return "io operate failed";
break;
case MATCH_HEADER_FAILED:
return "match header failed";
break;
default:
return "invalid error no";
}
......
......@@ -84,7 +84,8 @@ void BindShardWriter(py::module *m) {
void BindShardReader(const py::module *m) {
(void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
.def(py::init<>())
.def("open", (MSRStatus(ShardReader::*)(const std::string &, const int &, const std::vector<std::string> &,
.def("open", (MSRStatus(ShardReader::*)(const std::vector<std::string> &, bool, const int &,
const std::vector<std::string> &,
const std::vector<std::shared_ptr<ShardOperator>> &)) &
ShardReader::OpenPy)
.def("launch", &ShardReader::Launch)
......@@ -106,7 +107,8 @@ void BindShardIndexGenerator(const py::module *m) {
void BindShardSegment(py::module *m) {
(void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
.def(py::init<>())
.def("open", (MSRStatus(ShardSegment::*)(const std::string &, const int &, const std::vector<std::string> &,
.def("open", (MSRStatus(ShardSegment::*)(const std::vector<std::string> &, bool, const int &,
const std::vector<std::string> &,
const std::vector<std::shared_ptr<ShardOperator>> &)) &
ShardSegment::OpenPy)
.def("get_category_fields",
......
......@@ -72,7 +72,8 @@ enum MSRStatus {
ILLEGAL_PARAMETERS,
GET_PAGE_BY_GROUP_ID_FAILED,
GET_SYSTEM_STATE_FAILED,
IO_FAILED
IO_FAILED,
MATCH_HEADER_FAILED
};
// convert error no to string message
......
......@@ -35,10 +35,11 @@ class ShardHeader {
public:
ShardHeader();
MSRStatus Build(const std::string &file_path);
~ShardHeader() = default;
MSRStatus BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);
static std::pair<MSRStatus, json> BuildSingleHeader(const std::string &file_path);
/// \brief add the schema and save it
/// \param[in] schema the schema needs to be added
/// \return the last schema's id
......@@ -126,7 +127,7 @@ class ShardHeader {
MSRStatus FileToPages(const std::string dump_file_name);
private:
MSRStatus InitializeHeader(const std::vector<json> &headers);
MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);
/// \brief get the headers from all the shard data
/// \param[in] the shard data real path
......@@ -137,9 +138,9 @@ class ShardHeader {
MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
/// \brief check the binary file status
MSRStatus CheckFileStatus(const std::string &path);
static MSRStatus CheckFileStatus(const std::string &path);
std::pair<MSRStatus, json> ValidateHeader(const std::string &path);
static std::pair<MSRStatus, json> ValidateHeader(const std::string &path);
void ParseHeader(const json &header);
......@@ -149,7 +150,7 @@ class ShardHeader {
MSRStatus CheckIndexField(const std::string &field, const json &schema);
void ParsePage(const json &page);
void ParsePage(const json &page, int shard_index, bool load_dataset);
MSRStatus ParseStatistics(const json &statistics);
......
......@@ -68,23 +68,25 @@ class ShardReader {
virtual ~ShardReader();
/// \brief open files and initialize reader, c++ API
/// \param[in] file_path the path of ONE file, any file in dataset is fine
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] n_consumer number of threads when reading
/// \param[in] selected_columns column list to be populated
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \param[in] block_reader block-reader mode if true, otherwise row-reader mode
/// \return MSRStatus the status of MSRStatus
MSRStatus Open(const std::string &file_path, int n_consumer = 4,
MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const bool &block_reader = false);
/// \brief open files and initialize reader, python API
/// \param[in] file_path the path of ONE file, any file in dataset is fine
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] n_consumer number of threads when reading
/// \param[in] selected_columns column list to be populated
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \return MSRStatus the status of MSRStatus
MSRStatus OpenPy(const std::string &file_path, const int &n_consumer = 4,
MSRStatus OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer = 4,
const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {});
......@@ -114,11 +116,13 @@ class ShardReader {
int GetShardCount() const;
/// \brief get the number of rows in database
/// \param[in] file_path the path of ONE file, any file in dataset is fine
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object
/// \param[out] count # of rows
/// \return MSRStatus the status of MSRStatus
MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op, int64_t *count);
MSRStatus CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count);
/// \brief shuffle task with incremental seed
/// \return void
......@@ -220,7 +224,7 @@ class ShardReader {
std::vector<std::vector<json>> &column_values);
/// \brief initialize reader
MSRStatus Init(const std::string &file_path);
MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
/// \brief validate column list
MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns);
......@@ -292,8 +296,9 @@ class ShardReader {
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories);
/// \brief get number of classes
int64_t GetNumClasses(const std::string &file_path, const std::string &category_field);
int64_t GetNumClasses(const std::string &category_field);
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
/// \brief get exactly blob fields data by indices
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
std::vector<uint32_t> &ordered_selected_columns_index);
......
......@@ -36,9 +36,23 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
write_success_(true) {}
MSRStatus ShardIndexGenerator::Build() {
auto ret = ShardHeader::BuildSingleHeader(file_path_);
if (ret.first != SUCCESS) {
return FAILED;
}
auto json_header = ret.second;
auto ret2 = GetParentDir(file_path_);
if (SUCCESS != ret2.first) {
return FAILED;
}
std::vector<std::string> real_addresses;
for (const auto &path : json_header["shard_addresses"]) {
std::string abs_path = ret2.second + string(path);
real_addresses.emplace_back(abs_path);
}
ShardHeader header = ShardHeader();
if (header.Build(file_path_) != SUCCESS) {
MS_LOG(ERROR) << "Build shard schema failed.";
if (header.BuildDataset(real_addresses) == FAILED) {
return FAILED;
}
shard_header_ = header;
......
......@@ -47,20 +47,55 @@ ShardReader::ShardReader() {
block_reader_ = false;
}
MSRStatus ShardReader::Init(const std::string &file_path) {
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
if (!IsLegalFile(file_path)) {
return {FAILED, {}};
}
auto ret = ShardHeader::BuildSingleHeader(file_path);
if (ret.first != SUCCESS) {
return {FAILED, {}};
}
auto header = ret.second;
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
{"version", header["version"]}, {"index_fields", header["index_fields"]},
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
return {SUCCESS, header["shard_addresses"]};
}
MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
std::string file_path = file_paths[0];
json first_meta_data = json();
auto ret = GetMeta(file_path, first_meta_data);
if (ret.first != SUCCESS) {
return FAILED;
}
ShardHeader sh = ShardHeader();
if (sh.Build(file_path) == FAILED) {
if (file_paths.size() == 1 && load_dataset == true) {
auto ret2 = GetParentDir(file_path);
if (SUCCESS != ret2.first) {
return FAILED;
}
std::vector<std::string> real_addresses;
for (const auto &path : ret.second) {
std::string abs_path = ret2.second + string(path);
real_addresses.emplace_back(abs_path);
}
file_paths_ = real_addresses;
} else if (file_paths.size() >= 1 && load_dataset == false) {
file_paths_ = file_paths;
} else {
MS_LOG(ERROR) << "Error in parameter file_path or load_dataset.";
return FAILED;
}
shard_header_ = std::make_shared<ShardHeader>(sh);
header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize();
file_paths_ = shard_header_->GetShardAddresses();
for (const auto &file : file_paths_) {
json meta_data = json();
auto ret1 = GetMeta(file, meta_data);
if (ret1.first != SUCCESS) {
return FAILED;
}
if (meta_data != first_meta_data) {
MS_LOG(ERROR) << "Mindrecord files meta information is different.";
return FAILED;
}
sqlite3 *db = nullptr;
// sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
......@@ -91,7 +126,13 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
}
database_paths_.push_back(db);
}
ShardHeader sh = ShardHeader();
if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) {
return FAILED;
}
shard_header_ = std::make_shared<ShardHeader>(sh);
header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize();
num_rows_ = 0;
auto row_group_summary = ReadRowGroupSummary();
for (const auto &rg : row_group_summary) {
......@@ -248,7 +289,6 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
fs->close();
return FAILED;
}
json label_json = json::from_msgpack(label_raw);
json tmp;
if (!columns.empty()) {
......@@ -713,15 +753,9 @@ MSRStatus ShardReader::Finish() {
return SUCCESS;
}
int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) {
ShardHeader sh = ShardHeader();
if (sh.Build(file_path) == FAILED) {
return -1;
}
auto header = std::make_shared<ShardHeader>(sh);
auto file_paths = header->GetShardAddresses();
auto shard_count = file_paths.size();
auto index_fields = header->GetFields();
int64_t ShardReader::GetNumClasses(const std::string &category_field) {
auto shard_count = file_paths_.size();
auto index_fields = shard_header_->GetFields();
std::map<std::string, int64_t> map_schema_id_fields;
for (auto &field : index_fields) {
......@@ -742,7 +776,7 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
std::set<std::string> categories;
for (int x = 0; x < shard_count; x++) {
sqlite3 *db = nullptr;
int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
if (SQLITE_OK != rc) {
MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db);
return -1;
......@@ -756,16 +790,16 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
return categories.size();
}
MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op,
int64_t *count) {
if (Init(file_path) == FAILED) {
MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count) {
if (SUCCESS != Init(file_paths, load_dataset)) {
return FAILED;
}
int64_t num_samples = num_rows_;
if (std::dynamic_pointer_cast<ShardCategory>(op)) {
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
std::string category_field = category_op->GetCategoryField();
auto num_classes = GetNumClasses(file_path, category_field);
auto num_classes = GetNumClasses(category_field);
num_samples = category_op->GetNumSamples(num_rows_, num_classes);
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
num_samples = op->GetNumSamples(num_rows_, 0);
......@@ -779,12 +813,13 @@ MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::s
return SUCCESS;
}
MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
const std::vector<std::string> &selected_columns,
const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader) {
// Open file and set header by ShardReader
if (Init(file_path) == FAILED) {
return FAILED;
auto ret = Init(file_paths, load_dataset);
if (SUCCESS != ret) {
return ret;
}
auto thread_limit = GetMaxThreadNum();
if (n_consumer > thread_limit) {
......@@ -837,11 +872,11 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
return SUCCESS;
}
MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consumer,
MSRStatus ShardReader::OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
const std::vector<std::string> &selected_columns,
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
// Open file and set header by ShardReader
if (Init(file_path) == FAILED) {
if (SUCCESS != Init(file_paths, load_dataset)) {
return FAILED;
}
// should remove blob field from selected_columns when call from python
......
......@@ -174,12 +174,25 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if (!IsLegalFile(path)) {
return FAILED;
}
ShardHeader sh = ShardHeader();
if (sh.Build(path) == FAILED) {
auto ret1 = ShardHeader::BuildSingleHeader(path);
if (ret1.first != SUCCESS) {
return FAILED;
}
shard_header_ = std::make_shared<ShardHeader>(sh);
auto paths = shard_header_->GetShardAddresses();
auto json_header = ret1.second;
auto ret2 = GetParentDir(path);
if (SUCCESS != ret2.first) {
return FAILED;
}
std::vector<std::string> real_addresses;
for (const auto &path : json_header["shard_addresses"]) {
std::string abs_path = ret2.second + string(path);
real_addresses.emplace_back(abs_path);
}
ShardHeader header = ShardHeader();
if (header.BuildDataset(real_addresses) == FAILED) {
return FAILED;
}
shard_header_ = std::make_shared<ShardHeader>(header);
MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize());
if (ret == FAILED) {
return FAILED;
......@@ -188,7 +201,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if (ret == FAILED) {
return FAILED;
}
ret = Open(paths, true);
ret = Open(json_header["shard_addresses"], true);
if (ret == FAILED) {
MS_LOG(ERROR) << "Open file failed";
return FAILED;
......
......@@ -35,8 +35,9 @@ namespace mindrecord {
std::atomic<bool> thread_status(false);
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); }
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
shard_count_ = headers.size();
int shard_index = 0;
bool first = true;
for (const auto &header : headers) {
if (first) {
......@@ -54,7 +55,8 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
header_size_ = header["header_size"].get<uint64_t>();
page_size_ = header["page_size"].get<uint64_t>();
}
ParsePage(header["page"]);
ParsePage(header["page"], shard_index, load_dataset);
shard_index++;
}
return SUCCESS;
}
......@@ -136,40 +138,39 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path)
return {SUCCESS, json_header};
}
MSRStatus ShardHeader::Build(const std::string &file_path) {
std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &file_path) {
auto ret = ValidateHeader(file_path);
if (SUCCESS != ret.first) {
return FAILED;
}
json main_header = ret.second;
json addresses = main_header["shard_addresses"];
vector<string> real_addresses;
auto ret1 = GetParentDir(file_path);
if (SUCCESS != ret1.first) {
return FAILED;
return {FAILED, json()};
}
std::string parent_dir = ret1.second;
json raw_header = ret.second;
json header = {{"shard_addresses", raw_header["shard_addresses"]},
{"header_size", raw_header["header_size"]},
{"page_size", raw_header["page_size"]},
{"index_fields", raw_header["index_fields"]},
{"blob_fields", raw_header["schema"][0]["blob_fields"]},
{"schema", raw_header["schema"][0]["schema"]},
{"version", raw_header["version"]}};
return {SUCCESS, header};
}
for (const auto &addr : addresses) {
std::string absolute_path = parent_dir + string(addr);
real_addresses.emplace_back(absolute_path);
}
MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
uint32_t thread_num = std::thread::hardware_concurrency();
if (thread_num == 0) thread_num = kThreadNumber;
uint32_t work_thread_num = 0;
uint32_t addr_count = real_addresses.size();
int group_num = ceil(addr_count * 1.0 / thread_num);
uint32_t shard_count = file_paths.size();
int group_num = ceil(shard_count * 1.0 / thread_num);
std::vector<std::thread> thread_set(thread_num);
std::vector<json> headers(addr_count);
std::vector<json> headers(shard_count);
for (uint32_t x = 0; x < thread_num; ++x) {
int start_num = x * group_num;
int end_num = ((x + 1) * group_num > addr_count) ? addr_count : (x + 1) * group_num;
int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num;
if (start_num >= end_num) {
continue;
}
thread_set[x] =
std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), real_addresses);
std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths);
work_thread_num++;
}
......@@ -180,7 +181,7 @@ MSRStatus ShardHeader::Build(const std::string &file_path) {
thread_status = false;
return FAILED;
}
if (SUCCESS != InitializeHeader(headers)) {
if (SUCCESS != InitializeHeader(headers, load_dataset)) {
return FAILED;
}
return SUCCESS;
......@@ -247,7 +248,8 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
return SUCCESS;
}
void ShardHeader::ParsePage(const json &pages) {
void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
// set shard_index when load_dataset is false
if (pages_.empty() && shard_count_ <= kMaxShardCount) {
pages_.resize(shard_count_);
}
......@@ -267,7 +269,11 @@ void ShardHeader::ParsePage(const json &pages) {
std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id,
end_row_id, row_group_ids, page_size);
pages_[shard_id].push_back(std::move(parsed_page));
if (load_dataset == true) {
pages_[shard_id].push_back(std::move(parsed_page));
} else {
pages_[shard_index].push_back(std::move(parsed_page));
}
}
}
......@@ -709,7 +715,7 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
std::string line;
while (std::getline(page_in_handle, line)) {
ParsePage(json::parse(line));
ParsePage(json::parse(line), -1, true);
}
page_in_handle.close();
......
......@@ -2189,7 +2189,7 @@ class MindDataset(SourceDataset):
A source dataset that reads from shard files and database.
Args:
dataset_file (str): one of file names in dataset.
dataset_file (str, list[str]): One of file names or file list in dataset.
columns_list (list[str], optional): List of columns to be read (default=None).
num_parallel_workers (int, optional): The number of readers (default=None).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
......@@ -2214,6 +2214,10 @@ class MindDataset(SourceDataset):
shuffle=None, num_shards=None, shard_id=None,
block_reader=False, sampler=None):
super().__init__(num_parallel_workers)
if isinstance(dataset_file, list):
self.load_dataset = False
else:
self.load_dataset = True
self.dataset_file = dataset_file
self.columns_list = columns_list
self.global_shuffle = shuffle
......@@ -2256,6 +2260,7 @@ class MindDataset(SourceDataset):
def get_args(self):
args = super().get_args()
args["dataset_file"] = self.dataset_file
args["load_dataset"] = self.load_dataset
args["columns_list"] = self.columns_list
args["global_shuffle"] = self.global_shuffle
args["partitions"] = self.partitions
......@@ -2272,8 +2277,11 @@ class MindDataset(SourceDataset):
Return:
Number, number of batches.
"""
num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler)
if self.load_dataset:
dataset_file = [self.dataset_file]
else:
dataset_file = self.dataset_file
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler)
if self.partitions is not None and self.partitions[0] > 0:
if num_rows % self.partitions[0] == 0:
num_rows = num_rows // self.partitions[0]
......
......@@ -529,8 +529,11 @@ def check_minddataset(method):
dataset_file = param_dict.get('dataset_file')
if dataset_file is None:
raise ValueError("dataset_file is not provided.")
check_dataset_file(dataset_file)
if isinstance(dataset_file, list):
for f in dataset_file:
check_dataset_file(f)
else:
check_dataset_file(dataset_file)
check_param_type(nreq_param_int, param_dict, int)
check_param_type(nreq_param_list, param_dict, list)
......
......@@ -28,7 +28,7 @@ class FileReader:
Class to read MindRecord File series.
Args:
file_name (str): File name of MindRecord File.
file_name (str, list[str]): One of MindRecord File or file list.
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
It should not be smaller than 1 or larger than the number of CPU.
columns (list[str], optional): List of fields which correspond data would be read (default=None).
......@@ -38,8 +38,11 @@ class FileReader:
ParamValueError: If file_name, num_consumer or columns is invalid.
"""
def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
check_filename(file_name)
self._file_name = file_name
if isinstance(file_name, list):
for f in file_name:
check_filename(f)
else:
check_filename(file_name)
if num_consumer is not None:
if isinstance(num_consumer, int):
......
......@@ -28,7 +28,7 @@ class MindPage:
Class to read MindRecord File series in pagination.
Args:
file_name (str): File name of MindRecord File.
file_name (str): One of MindRecord File or file list.
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
It should not be smaller than 1 or larger than the number of CPU.
......@@ -37,8 +37,11 @@ class MindPage:
MRMInitSegmentError: If failed to initialize ShardSegment.
"""
def __init__(self, file_name, num_consumer=4):
check_filename(file_name)
self._file_name = file_name
if isinstance(file_name, list):
for f in file_name:
check_filename(f)
else:
check_filename(file_name)
if num_consumer is not None:
if isinstance(num_consumer, int):
......
......@@ -35,7 +35,7 @@ class ShardReader:
Open file and prepare to read MindRecord File.
Args:
file_name (str): File name of MindRecord File.
file_name (str, list[str]): File names of MindRecord File.
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
columns (list[str]): List of fields which correspond data would be read.
operator(int): Reserved parameter for operators. Default: None.
......@@ -48,7 +48,12 @@ class ShardReader:
"""
columns = columns if columns else []
operator = operator if operator else []
ret = self._reader.open(file_name, num_consumer, columns, operator)
if isinstance(file_name, list):
load_dataset = False
else:
load_dataset = True
file_name = [file_name]
ret = self._reader.open(file_name, load_dataset, num_consumer, columns, operator)
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to open {}.".format(file_name))
raise MRMOpenError
......
......@@ -40,7 +40,7 @@ class ShardSegment:
Initialize the ShardSegment.
Args:
file_name (str): File name of MindRecord File.
file_name (str, list[str]): File names of MindRecord File.
num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
columns (list[str]): List of fields which correspond data would be read.
operator(int): Reserved parameter for operators. Default: None.
......@@ -53,7 +53,12 @@ class ShardSegment:
"""
self._columns = columns if columns else []
operator = operator if operator else []
ret = self._segment.open(file_name, num_consumer, self._columns, operator)
if isinstance(file_name, list):
load_dataset = False
else:
load_dataset = True
file_name = [file_name]
ret = self._segment.open(file_name, load_dataset, num_consumer, self._columns, operator)
if ret != SUCCESS:
logger.error("Failed to open {}.".format(file_name))
raise MRMOpenError
......
......@@ -62,7 +62,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBasic) {
std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list);
......@@ -132,7 +133,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordSample) {
std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list)
......@@ -203,7 +205,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordShuffle) {
std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list)
......@@ -277,7 +280,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordCategory) {
std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list)
......@@ -345,7 +349,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) {
std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list);
......@@ -426,7 +431,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) {
std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4)
.SetBlockReader()
......@@ -507,7 +513,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordInvalidColumnList) {
std::shared_ptr<MindRecordOp> my_mindrecord_op;
MindRecordOp::Builder builder;
builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
.SetLoadDataset(true)
.SetRowsPerBuffer(3)
.SetNumMindRecordWorkers(4)
.SetColumnsToLoad(column_list);
......
......@@ -63,7 +63,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardSample>(kSampleCount));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -89,7 +89,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -115,7 +115,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -144,7 +144,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
ASSERT_TRUE(partitions.second == 2);
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -168,7 +168,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
ops.push_back(std::make_shared<ShardPkSample>("label", 2));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name},true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -193,7 +193,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name},true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -223,7 +223,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
ops.push_back(std::make_shared<ShardCategory>(categories));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -254,7 +254,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
ops.push_back(std::make_shared<ShardShuffle>(1));
ShardReader dataset;
dataset.Open(file_name, 16, column_list, ops);
dataset.Open({file_name}, true, 16, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -279,7 +279,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
ops.push_back(std::make_shared<ShardShuffle>(1));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -306,7 +306,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
ops.push_back(std::make_shared<ShardSample>(kSampleSize));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -333,7 +333,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
ops.push_back(std::make_shared<ShardSample>(35));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -357,11 +357,11 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
ops.push_back(std::make_shared<ShardShuffle>(1));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
ShardReader compare_dataset;
compare_dataset.Open(file_name, 4, column_list);
compare_dataset.Open({file_name},true, 4, column_list);
compare_dataset.Launch();
int i = 0;
......@@ -396,7 +396,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
ops.push_back(std::make_shared<ShardShuffle>(21));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -430,7 +430,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
ops.push_back(std::make_shared<ShardCategory>(categories));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -464,7 +464,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
ops.push_back(std::make_shared<ShardCategory>(categories));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name},true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......@@ -502,7 +502,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
ops.push_back(std::make_shared<ShardShuffle>(100));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
int i = 0;
......
......@@ -55,7 +55,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
auto column_list = std::vector<std::string>{"file_name"};
ShardReader dataset;
dataset.Open(file_name, 4, column_list);
dataset.Open({file_name}, true, 4, column_list);
dataset.Launch();
while (true) {
......@@ -78,7 +78,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardSample>(17));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Open({file_name}, true, 4, column_list, ops);
dataset.Launch();
while (true) {
......@@ -103,7 +103,7 @@ TEST_F(TestShardReader, TestShardReaderBlock) {
ops.push_back(std::make_shared<ShardSample>(3));
ShardReader dataset;
const bool kBlockReader = true;
dataset.Open(file_name, 4, column_list, ops, kBlockReader);
dataset.Open({file_name}, true, 4, column_list, ops, kBlockReader);
dataset.Launch();
while (true) {
......@@ -123,7 +123,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
MS_LOG(INFO) << FormatInfo("Test read imageNet");
std::string file_name = "./imagenet.shard01";
ShardReader dataset;
dataset.Open(file_name);
dataset.Open({file_name}, true);
dataset.Launch();
while (true) {
......@@ -143,7 +143,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label"};
ShardReader dataset;
MSRStatus ret = dataset.Open(file_name, 4, column_list);
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
dataset.Launch();
......@@ -164,7 +164,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"file_namex"};
ShardReader dataset;
MSRStatus ret = dataset.Open(file_name, 4, column_list);
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST);
}
......@@ -172,7 +172,7 @@ TEST_F(TestShardReader, TestShardVersion) {
MS_LOG(INFO) << FormatInfo("Test shard version");
std::string file_name = "./imagenet.shard01";
ShardReader dataset;
MSRStatus ret = dataset.Open(file_name, 4);
MSRStatus ret = dataset.Open({file_name}, true, 4);
ASSERT_EQ(ret, SUCCESS);
dataset.Launch();
......@@ -195,7 +195,7 @@ TEST_F(TestShardReader, TestShardReaderDir) {
auto column_list = std::vector<std::string>{"file_name"};
ShardReader dataset;
MSRStatus ret = dataset.Open(file_name, 4, column_list);
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, FAILED);
}
......@@ -205,7 +205,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
auto column_list = std::vector<std::string>{"file_name"};
ShardReader dataset;
dataset.Open(file_name, -481565535, column_list);
dataset.Open({file_name}, true, -481565535, column_list);
dataset.Launch();
while (true) {
......
......@@ -59,7 +59,7 @@ TEST_F(TestShardSegment, TestShardSegment) {
std::string file_name = "./imagenet.shard01";
ShardSegment dataset;
dataset.Open(file_name, 4);
dataset.Open({file_name}, true, 4);
auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
......@@ -97,7 +97,7 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) {
std::string file_name = "./imagenet.shard01";
ShardSegment dataset;
dataset.Open(file_name, 4);
dataset.Open({file_name}, true, 4);
auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
......@@ -121,7 +121,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) {
std::string file_name = "./imagenet.shard01";
ShardSegment dataset;
dataset.Open(file_name, 4);
dataset.Open({file_name}, true, 4);
auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
......@@ -143,7 +143,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) {
std::string file_name = "./imagenet.shard01";
ShardSegment dataset;
dataset.Open(file_name, 4);
dataset.Open({file_name}, true, 4);
auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
......@@ -165,7 +165,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) {
std::string file_name = "./imagenet.shard01";
ShardSegment dataset;
dataset.Open(file_name, 4);
dataset.Open({file_name}, true, 4);
auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
......
......@@ -60,7 +60,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
std::string filename = "./OneSample.shard01";
ShardReader dataset;
MSRStatus ret = dataset.Open(filename, 4);
MSRStatus ret = dataset.Open({filename}, true, 4);
ASSERT_EQ(ret, SUCCESS);
dataset.Launch();
......@@ -756,7 +756,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "file_name", "data"};
ShardReader dataset;
MSRStatus ret = dataset.Open(filename, 4, column_list);
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
dataset.Launch();
......@@ -842,7 +842,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "file_name"};
ShardReader dataset;
MSRStatus ret = dataset.Open(filename, 4, column_list);
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
dataset.Launch();
......@@ -936,7 +936,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "data"};
ShardReader dataset;
MSRStatus ret = dataset.Open(filename, 4, column_list);
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
dataset.Launch();
......@@ -1043,7 +1043,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
filename = "./TenSampleFortyShard.shard01";
ShardReader dataset;
MSRStatus ret = dataset.Open(filename, 4);
MSRStatus ret = dataset.Open({filename}, true, 4);
ASSERT_EQ(ret, SUCCESS);
dataset.Launch();
......
......@@ -32,6 +32,8 @@ from mindspore.mindrecord import FileWriter
FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord"
CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
......@@ -111,7 +113,6 @@ def test_cv_minddataset_writer_tutorial():
os.remove("{}".format(x))
os.remove("{}.db".format(x))
def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "file_name", "label"]
......@@ -247,6 +248,126 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem
assert num_iter == 20
def test_cv_minddataset_reader_file_list(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)], columns_list, num_readers)
assert data_set.get_dataset_size() == 10
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 10
def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
data_set = ds.MindDataset([CV_FILE_NAME + "0"], columns_list, num_readers)
assert data_set.get_dataset_size() < 10
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter < 10
def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
if os.path.exists(CV1_FILE_NAME):
os.remove(CV1_FILE_NAME)
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
os.remove("{}.db".format(CV1_FILE_NAME))
if os.path.exists(CV2_FILE_NAME):
os.remove(CV2_FILE_NAME)
if os.path.exists("{}.db".format(CV2_FILE_NAME)):
os.remove("{}.db".format(CV2_FILE_NAME))
writer = FileWriter(CV1_FILE_NAME, 1)
data = get_data(CV_DIR_NAME)
cv_schema_json = {"id": {"type": "int32"},
"file_name": {"type": "string"},
"label": {"type": "int32"},
"data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "CV1_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
writer = FileWriter(CV2_FILE_NAME, 1)
data = get_data(CV_DIR_NAME)
cv_schema_json = {"id": {"type": "int32"},
"file_name": {"type": "string"},
"label": {"type": "int32"},
"data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "CV2_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
columns_list = ["data", "file_name", "label"]
num_readers = 4
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME], columns_list, num_readers)
assert data_set.get_dataset_size() == 30
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 30
if os.path.exists(CV1_FILE_NAME):
os.remove(CV1_FILE_NAME)
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
os.remove("{}.db".format(CV1_FILE_NAME))
if os.path.exists(CV2_FILE_NAME):
os.remove(CV2_FILE_NAME)
if os.path.exists("{}.db".format(CV2_FILE_NAME)):
os.remove("{}.db".format(CV2_FILE_NAME))
def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None
os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
data = get_data(CV_DIR_NAME)
cv_schema_json = {"id": {"type": "int32"},
"file_name": {"type": "string"},
"label": {"type": "int32"},
"data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "CV1_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
columns_list = ["data", "file_name", "label"]
num_readers = 4
data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(2)] + [CV1_FILE_NAME + str(x) for x in range(2, 4)], columns_list, num_readers)
assert data_set.get_dataset_size() < 20
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter < 20
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
......
......@@ -22,6 +22,7 @@ import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter
CV_FILE_NAME = "./imagenet.mindrecord"
CV1_FILE_NAME = "./imagenet1.mindrecord"
def create_cv_mindrecord(files_num):
......@@ -37,6 +38,31 @@ def create_cv_mindrecord(files_num):
writer.commit()
def create_diff_schema_cv_mindrecord(files_num):
"""tutorial for cv dataset writer."""
os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
writer = FileWriter(CV1_FILE_NAME, files_num)
cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name_1", "label"])
writer.write_raw_data(data)
writer.commit()
def create_diff_page_size_cv_mindrecord(files_num):
"""tutorial for cv dataset writer."""
os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None
os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None
writer = FileWriter(CV1_FILE_NAME, files_num)
writer.set_page_size(1<< 26) #64MB
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data(data)
writer.commit()
def test_cv_lack_json():
"""tutorial for cv minderdataset."""
create_cv_mindrecord(1)
......@@ -111,3 +137,34 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle():
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
def test_cv_minddataset_reader_different_schema():
create_cv_mindrecord(1)
create_diff_schema_cv_mindrecord(1)
columns_list = ["data", "label"]
num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
num_readers)
num_iter = 0
for item in data_set.create_dict_iterator():
num_iter += 1
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
os.remove(CV1_FILE_NAME)
os.remove("{}.db".format(CV1_FILE_NAME))
def test_cv_minddataset_reader_different_page_size():
create_cv_mindrecord(1)
create_diff_page_size_cv_mindrecord(1)
columns_list = ["data", "label"]
num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
num_readers)
num_iter = 0
for item in data_set.create_dict_iterator():
num_iter += 1
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
os.remove(CV1_FILE_NAME)
os.remove("{}.db".format(CV1_FILE_NAME))
......@@ -202,6 +202,16 @@ def test_cv_file_reader_tutorial():
assert count == 10
reader.close()
def test_cv_file_reader_file_list():
"""tutorial for cv file partial reader."""
reader = FileReader([CV_FILE_NAME + str(x) for x in range(FILES_NUM)])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 3
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 10
def test_cv_file_reader_partial_tutorial():
"""tutorial for cv file partial reader."""
reader = FileReader(CV_FILE_NAME + "0")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册