Commit b091f74c authored by mindspore-ci-bot, committed by Gitee

!3016 Add CSV dataset loader

Merge pull request !3016 from jiangzhiwen/dataset/csv
@@ -31,6 +31,7 @@
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h" #include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h" #include "minddata/dataset/engine/datasetops/source/manifest_op.h"
@@ -88,6 +89,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp},
{kEpochCtrl, &DEPipeline::ParseEpochCtrlOp},
{kCsv, &DEPipeline::ParseCsvOp},
{kSentencePieceVocab, &DEPipeline::ParseBuildSentencePieceVocabOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
@@ -1848,6 +1850,86 @@ Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num
return Status::OK();
}
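// Parses the arguments coming from the Python-side CSVDataset, builds a CsvOp, and, when a global
// shuffle is requested, injects a ShuffleOp on top and returns the resulting (top, bottom) subtree.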
Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list;
std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
if (!args["dataset_files"].is_none()) {
files_list = ToStringVector(args["dataset_files"]);
(void)builder->SetCsvFilesList(files_list);
} else {
RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
}
// Optional arguments
bool shuffle_required = false;
int64_t num_devices = 0;
std::vector<std::string> col_names;
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "shuffle_global") {
shuffle_required = ToBool(value);
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_shards") {
num_devices = ToInt(value);
(void)builder->SetNumDevices(num_devices);
} else if (key == "shard_id") {
(void)builder->SetDeviceId(ToInt(value));
} else if (key == "field_delim") {
(void)builder->SetFieldDelim(ToString(value)[0]);
} else if (key == "column_defaults") {
py::list py_object_list = py::reinterpret_borrow<py::list>(value);
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
for (auto l : py_object_list) {
std::string type_s = (std::string)py::str(l.get_type().attr("__name__"));
if (type_s == "int") {
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, ToInt(l)));
} else if (type_s == "float") {
column_default_list.push_back(std::make_shared<CsvOp::Record<float>>(CsvOp::FLOAT, ToFloat(l)));
} else if (type_s == "str") {
column_default_list.push_back(std::make_shared<CsvOp::Record<std::string>>(CsvOp::STRING, ToString(l)));
} else {
RETURN_STATUS_UNEXPECTED("Record type is not allowed");
}
}
(void)builder->SetColumDefault(column_default_list);
} else if (key == "column_names") {
col_names = ToStringVector(value);
(void)builder->SetColumName(col_names);
}
}
}
std::shared_ptr<CsvOp> csv_op;
RETURN_IF_NOT_OK(builder->Build(&csv_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op));
*top = csv_op;
if (shuffle_required) {
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t shuffle_size = 0;
int64_t num_rows = 0;
// First, get the number of rows in the dataset and then compute the shuffle size
RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(files_list, col_names.empty(), &num_rows));
RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size));
// Add the shuffle op over top of this op and return the subtree (top/bottom) to caller
RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, csv_op, &shuffle_op));
*top = shuffle_op;
*bottom = csv_op;
}
return Status::OK();
}
// Helper function to inject a shuffle operator over top of the current operation being built.
Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *shuffle_op) {
......
@@ -73,6 +73,7 @@ enum OpName {
kClue,
kEpochCtrl,
kSentencePieceVocab,
kCsv
};
// The C++ binder class that we expose to the python script.
@@ -201,6 +202,8 @@ class DEPipeline {
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;
......
@@ -19,6 +19,7 @@
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h" #include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/io_block.h"
...@@ -277,6 +278,17 @@ void bindDatasetOps(py::module *m) { ...@@ -277,6 +278,17 @@ void bindDatasetOps(py::module *m) {
return count; return count;
}); });
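// Expose CsvOp row counting to Python; CSVDataset.get_dataset_size() calls this to size the dataset.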
(void)py::class_<CsvOp, DatasetOp, std::shared_ptr<CsvOp>>(*m, "CsvOp")
.def_static("get_num_rows", [](const py::list &files, bool csv_header) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto file : files) {
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
}
THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count));
return count;
});
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_num_rows",
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
@@ -1039,8 +1051,9 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile)
.value("CLUE", OpName::kClue) .value("EPOCHCTRL", OpName::kEpochCtrl)
.value("EPOCHCTRL", OpName::kEpochCtrl); .value("CSV", OpName::kCsv)
.value("CLUE", OpName::kClue);
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
.value("DE_JIEBA_MIX", JiebaMode::kMix)
......
@@ -12,6 +12,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
celeba_op.cc
text_file_op.cc
clue_op.cc
csv_op.cc
)
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
@@ -29,4 +30,4 @@ if (ENABLE_PYTHON)
)
endif()
add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES})
\ No newline at end of file
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include <fstream>
#include <iomanip>
#include <stdexcept>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
CsvOp::Builder::Builder()
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
builder_rows_per_buffer_ = config_manager->rows_per_buffer();
builder_worker_connector_size_ = config_manager->worker_connector_size();
}
Status CsvOp::Builder::ValidateInputs() const {
std::string err;
err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : "";
return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err);
}
Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
RETURN_IF_NOT_OK(ValidateInputs());
// Throttle the number of workers if we have more workers than files!
if (static_cast<size_t>(builder_num_workers_) > builder_csv_files_list_.size()) {
builder_num_workers_ = builder_csv_files_list_.size();
MS_LOG(WARNING) << "CsvOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
}
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
RETURN_IF_NOT_OK(csv_op->Init());
*op = std::move(csv_op);
return Status::OK();
}
CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
csv_files_list_(std::move(csv_files_list)),
field_delim_(field_delim),
column_default_list_(column_default),
column_name_list_(column_name),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),
num_samples_(num_samples),
filename_index_(std::make_unique<StringIndex>()),
load_jagged_connector_(true),
shuffle_files_(shuffle_files),
finished_reading_dataset_(false),
num_devices_(num_device),
device_id_(device_id),
load_io_block_queue_(true) {
worker_connector_size_ = worker_connector_size;
}
Status CsvOp::Init() {
RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_));
int32_t safe_queue_size = static_cast<int32_t>(std::ceil(csv_files_list_.size() / static_cast<double>(num_workers_)) + 1);
io_block_queues_.Init(num_workers_, safe_queue_size);
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
jagged_buffer_connector_ = std::make_shared<JaggedConnector>(num_workers_, 1, worker_connector_size_);
return Status::OK();
}
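// Converts the characters accumulated for the current field into a Tensor of the type given by
// column_default_ and stores it at the current (row, column) cell of the tensor table.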
int CsvOp::CsvParser::put_record(char c) {
std::string s = std::string(str_buf_.begin(), str_buf_.begin() + pos_);
std::shared_ptr<Tensor> t;
switch (column_default_[cur_col_]->type) {
case CsvOp::INT:
Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32));
t->SetItemAt<int32_t>({0}, std::stoi(s));
break;
case CsvOp::FLOAT:
Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32));
t->SetItemAt<float>({0}, std::stof(s));
break;
case CsvOp::STRING:
Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar());
break;
default:
Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar());
break;
}
(*tensor_table_)[cur_row_][cur_col_] = std::move(t);
pos_ = 0;
cur_col_++;
return 0;
}
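// Finishes the current row. Rows before start_offset_ or at/after end_offset_ are skipped so that each
// shard only materializes its own row range; a full buffer is flushed to the connector every
// csv_rows_per_buffer_ rows.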
int CsvOp::CsvParser::put_row(char c) {
if (total_rows_ < start_offset_) {
total_rows_++;
cur_col_ = 0;
return 0;
}
if (total_rows_ >= end_offset_) {
return 0;
}
put_record(c);
total_rows_++;
cur_row_++;
cur_col_ = 0;
if (cur_row_ == csv_rows_per_buffer_) {
cur_buffer_->set_tensor_table(std::move(tensor_table_));
buffer_connector_->Add(worker_id_, std::move(cur_buffer_));
cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
tensor_table_ = std::make_unique<TensorQTable>();
cur_row_ = 0;
}
return 0;
}
int CsvOp::CsvParser::end_file(char c) {
if (cur_col_ > 0) {
put_row(c);
}
if (cur_row_ > 0) {
cur_buffer_->set_tensor_table(std::move(tensor_table_));
buffer_connector_->Add(worker_id_, std::move(cur_buffer_));
}
return 0;
}
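// Lightweight per-character scanner driven by the 'sdl' state diagram; it only counts record rows and
// performs no field extraction.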
int CsvOp::CsvParser::countRows(char c) {
Message m;
if (c == '"') {
m = Message::MS_QUOTE;
} else if (c == '\r' || c == '\n' || c == std::char_traits<char>::eof()) {
m = Message::MS_END_OF_LINE;
} else {
m = Message::MS_NORMAL;
}
StateDiagram::iterator it = sdl.find({cur_state_, m});
if (it == sdl.end()) {
return -1;
}
cur_state_ = it->second.first;
return it->second.second(*this, c);
}
Status CsvOp::CsvParser::initCsvParser() {
str_buf_.resize(CSV_BUFFER_SIZE);
// State diagram for counting rows
sdl = {// START_OF_FILE
// ┌───────────┬───────────┬─────────────┐
// │ abc │ " │ \n │
// ├───────────┼───────────┼─────────────┤
// │ UNQUOTE │ QUOTE │ END_OF_LINE │
// ├───────────┼───────────┼─────────────┤
// | null_func │ null_func │ null_func │
// └───────────┴───────────┴─────────────┘
{{State::START_OF_FILE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}},
{{State::START_OF_FILE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}},
{{State::START_OF_FILE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}},
// UNQUOTE
// ┌───────────┬───────────┬─────────────┐
// │ abc │ " │ \n │
// ├───────────┼───────────┼─────────────┤
// │ UNQUOTE │ QUOTE │ END_OF_LINE │
// ├───────────┼───────────┼─────────────┤
// | null_func │ null_func │ add_row │
// └───────────┴───────────┴─────────────┘
{{State::UNQUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}},
{{State::UNQUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}},
{{State::UNQUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::add_row}},
// QUOTE
// ┌───────────┬──────────────┬───────────┐
// │ abc │ " │ \n │
// ├───────────┼──────────────┼───────────┤
// │ QUOTE │ SECOND_QUOTE │ QUOTE │
// ├───────────┼──────────────┼───────────┤
// | null_func │ null_func │ null_func │
// └───────────┴──────────────┴───────────┘
{{State::QUOTE, Message::MS_NORMAL}, {State::QUOTE, &CsvParser::null_func}},
{{State::QUOTE, Message::MS_QUOTE}, {State::SECOND_QUOTE, &CsvParser::null_func}},
{{State::QUOTE, Message::MS_END_OF_LINE}, {State::QUOTE, &CsvParser::null_func}},
// SECOND_QUOTE
// ┌───────────┬───────────┬─────────────┐
// │ abc │ " │ \n │
// ├───────────┼───────────┼─────────────┤
// │ UNQUOTE │ QUOTE │ END_OF_LINE │
// ├───────────┼───────────┼─────────────┤
// | null_func │ null_func │ add_row │
// └───────────┴───────────┴─────────────┘
{{State::SECOND_QUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}},
{{State::SECOND_QUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}},
{{State::SECOND_QUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::add_row}},
// END_OF_LINE
// ┌───────────┬───────────┬─────────────┐
// │ abc │ " │ \n │
// ├───────────┼───────────┼─────────────┤
// │ UNQUOTE │ QUOTE │ END_OF_LINE │
// ├───────────┼───────────┼─────────────┤
// | null_func │ null_func │ null_func │
// └───────────┴───────────┴─────────────┘
{{State::END_OF_LINE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}},
{{State::END_OF_LINE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}},
{{State::END_OF_LINE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}};
// State diagram for CSV parser
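// For example, with ',' as the field delimiter, the line  a,"b""c",3\n  walks through 'sd' as:
//   'a'  START_OF_FILE -> UNQUOTE       (start a new row, store 'a')
//   ','  UNQUOTE       -> DELIM         (put_record, field "a")
//   '"'  DELIM         -> QUOTE         (begin quoted field)
//   'b'  QUOTE         -> QUOTE         (store 'b')
//   '"'  QUOTE         -> SECOND_QUOTE
//   '"'  SECOND_QUOTE  -> QUOTE         (escaped quote, store '"')
//   'c'  QUOTE         -> QUOTE         (store 'c')
//   '"'  QUOTE         -> SECOND_QUOTE
//   ','  SECOND_QUOTE  -> DELIM         (put_record, field  b"c )
//   '3'  DELIM         -> UNQUOTE       (store '3')
//   '\n' UNQUOTE       -> END_OF_LINE   (put_row)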
sd = {// START_OF_FILE
// ┌───────────┬──────────┬──────────┬────────────────┬────────────────┐
// │ abc │ , │ " │ \n │ EOF │
// ├───────────┼──────────┼──────────┼────────────────┼────────────────┤
// │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │
// ├───────────┼──────────┼──────────┼────────────────┼────────────────┤
// | lambda │ lambda │ lambda │ null_func │ null_func │
// └───────────┴──────────┴──────────┴────────────────┴────────────────┘
{{State::START_OF_FILE, Message::MS_NORMAL},
{State::UNQUOTE,
[this](CsvParser &, char c) -> int {
this->tensor_table_ = std::make_unique<TensorQTable>();
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
this->str_buf_[0] = c;
this->pos_ = 1;
return 0;
}}},
{{State::START_OF_FILE, Message::MS_DELIM},
{State::DELIM,
[this](CsvParser &, char c) -> int {
this->tensor_table_ = std::make_unique<TensorQTable>();
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
this->put_record(c);
return 0;
}}},
{{State::START_OF_FILE, Message::MS_QUOTE},
{State::QUOTE,
[this](CsvParser &, char c) -> int {
this->tensor_table_ = std::make_unique<TensorQTable>();
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
this->pos_ = 0;
return 0;
}}},
{{State::START_OF_FILE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}},
{{State::START_OF_FILE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::null_func}},
// UNQUOTE
// ┌───────────┬────────────┬───────────┬─────────────┬────────────────┐
// │ abc │ , │ " │ \n │ EOF │
// ├───────────┼────────────┼───────────┼─────────────┼────────────────┤
// │ UNQUOTE │ DELIM │ EXCEPTION │ END_OF_LINE │ END_OF_FILE │
// ├───────────┼────────────┼───────────┼─────────────┼────────────────┤
// | put_char │ put_record │ exception │ put_row │ end_file │
// └───────────┴────────────┴───────────┴─────────────┴────────────────┘
{{State::UNQUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::put_char}},
{{State::UNQUOTE, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}},
{{State::UNQUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}},
{{State::UNQUOTE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}},
// UNQUOTE-Exception
{{State::UNQUOTE, Message::MS_QUOTE}, {State::EXCEPTION, &CsvParser::catch_exception}},
// DELIM
// ┌───────────┬────────────┬───────────┬─────────────┬────────────────┐
// │ abc │ , │ " │ \n │ EOF │
// ├───────────┼────────────┼───────────┼─────────────┼────────────────┤
// │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │
// ├───────────┼────────────┼───────────┼─────────────┼────────────────┤
// | put_char │ put_record │ lambda │ put_row │ end_file │
// └───────────┴────────────┴───────────┴─────────────┴────────────────┘
{{State::DELIM, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::put_char}},
{{State::DELIM, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}},
{{State::DELIM, Message::MS_QUOTE},
{State::QUOTE,
[this](CsvParser &, char c) -> int {
this->pos_ = 0;
return 0;
}}},
{{State::DELIM, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}},
{{State::DELIM, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}},
// QUOTE
// ┌───────────┬──────────┬──────────────┬──────────┬────────────────┐
// │ abc │ , │ " │ \n │ EOF │
// ├───────────┼──────────┼──────────────┼──────────┼────────────────┤
// │ QUOTE │ QUOTE │ SECOND_QUOTE │ QUOTE │ EXCEPTION │
// ├───────────┼──────────┼──────────────┼──────────┼────────────────┤
// | put_char │ put_char │ null_func │ put_char │ exception │
// └───────────┴──────────┴──────────────┴──────────┴────────────────┘
{{State::QUOTE, Message::MS_NORMAL}, {State::QUOTE, &CsvParser::put_char}},
{{State::QUOTE, Message::MS_DELIM}, {State::QUOTE, &CsvParser::put_char}},
{{State::QUOTE, Message::MS_QUOTE}, {State::SECOND_QUOTE, &CsvParser::null_func}},
{{State::QUOTE, Message::MS_END_OF_LINE}, {State::QUOTE, &CsvParser::put_char}},
// QUOTE-Exception
{{State::QUOTE, Message::MS_END_OF_FILE}, {State::EXCEPTION, &CsvParser::catch_exception}},
// SECOND_QUOTE
// ┌───────────┬────────────┬──────────┬─────────────┬────────────────┐
// │ abc │ , │ " │ \n │ EOF │
// ├───────────┼────────────┼──────────┼─────────────┼────────────────┤
// │ EXCEPTION │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │
// ├───────────┼────────────┼──────────┼─────────────┼────────────────┤
// | exception │ put_record │ put_char │ put_row │ end_file │
// └───────────┴────────────┴──────────┴─────────────┴────────────────┘
{{State::SECOND_QUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::put_char}},
{{State::SECOND_QUOTE, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}},
{{State::SECOND_QUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}},
{{State::SECOND_QUOTE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}},
// SECOND_QUOTE-Exception
{{State::SECOND_QUOTE, Message::MS_NORMAL}, {State::EXCEPTION, &CsvParser::catch_exception}},
// END_OF_LINE
// ┌─────────┬────────┬────────┬─────────────┬─────────────┐
// │ abc │ , │ " │ \n │ EOF │
// ├─────────┼────────┼────────┼─────────────┼─────────────┤
// │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │
// ├─────────┼────────┼────────┼─────────────┼─────────────┤
// | lambda │ lambda │ lambda │ null_func │ end_file │
// └─────────┴────────┴────────┴─────────────┴─────────────┘
{{State::END_OF_LINE, Message::MS_NORMAL},
{State::UNQUOTE,
[this](CsvParser &, char c) -> int {
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
this->str_buf_[0] = c;
this->pos_ = 1;
return 0;
}}},
{{State::END_OF_LINE, Message::MS_DELIM},
{State::DELIM,
[this](CsvParser &, char c) -> int {
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
this->put_record(c);
return 0;
}}},
{{State::END_OF_LINE, Message::MS_QUOTE},
{State::QUOTE,
[this](CsvParser &, char c) -> int {
this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr));
return 0;
}}},
{{State::END_OF_LINE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}},
{{State::END_OF_LINE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}};
return Status::OK();
}
Status CsvOp::Reset() {
load_jagged_connector_ = true;
load_io_block_queue_ = true;
RETURN_IF_NOT_OK(ParallelOp::Reset());
NotifyToFillIOBlockQueue();
return Status::OK();
}
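// Parses one CSV file, feeding rows in the [start_offset, end_offset) range assigned to this worker
// into the jagged buffer connector. When no column names were supplied, the first line is treated as
// the header and skipped.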
Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id) {
CsvParser csv_parser(worker_id, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_);
csv_parser.setStartOffset(start_offset);
csv_parser.setEndOffset(end_offset);
std::ifstream ifs;
ifs.open(file, std::ifstream::in);
if (column_name_list_.empty()) {
std::string tmp;
getline(ifs, tmp);
}
csv_parser.Reset();
try {
while (ifs.good()) {
char chr = ifs.get();
if (csv_parser.processMessage(chr) != 0) {
RETURN_STATUS_UNEXPECTED("Failed to parse CSV file " + file + ":" + std::to_string(csv_parser.total_rows_));
}
}
} catch (std::invalid_argument &ia) {
std::string err_row = std::to_string(csv_parser.total_rows_);
RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", invalid argument of " + std::string(ia.what()));
} catch (std::out_of_range &oor) {
std::string err_row = std::to_string(csv_parser.total_rows_);
RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of Range error: " + std::string(oor.what()));
}
return Status::OK();
}
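// Master thread loop: launches the IO block filler and worker threads, then drains parsed buffers from
// the jagged connector, re-tags them with sequential buffer ids and forwards them (plus EOE/EOF markers)
// to the output connector until the dataset has been fully read.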
Status CsvOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
// launch one thread, responsible for filling IoBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&CsvOp::WaitToFillIOBlockQueue, this)));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CsvOp::WorkerEntry, this, std::placeholders::_1)));
// must be called after launching workers.
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
load_io_block_queue_ = true;
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
if (buffer->eoe()) {
workers_done++;
} else if (num_samples_ == 0 || rows_read < num_samples_) {
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
}
rows_read += buffer->NumRows();
buffer->set_id(buffer_id++);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
} else {
// end of epoch
load_jagged_connector_ = false;
load_io_block_queue_ = false;
}
}
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
}
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
RETURN_IF_NOT_OK(PostEndOfData());
return Status::OK();
}
Status CsvOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::unique_ptr<FilenameBlock> io_block;
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
while (!io_block->eof()) {
if (!io_block->eoe()) {
if (load_jagged_connector_) {
std::string filename;
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t start_offset = io_block->GetStartOffset();
int64_t end_offset = io_block->GetEndOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
}
} else {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
}
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
}
return Status::OK();
}
// A print method typically used for debugging
void CsvOp::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <CsvOp>:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n";
for (int i = 0; i < csv_files_list_.size(); ++i) {
out << " " << csv_files_list_[i];
}
out << "\n\n";
}
}
// Pops an element from a queue in io_block_queues
Status CsvOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
return Status::OK();
}
// Pushes an element to a queue in io_block_queues
Status CsvOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
return Status::OK();
}
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
std::mt19937 rng(seed);
std::shuffle(i_keys->begin(), i_keys->end(), rng);
}
Status CsvOp::WaitToFillIOBlockQueue() {
// must be called first if called by a worker spawned by the task group
TaskManager::FindMe()->Post();
std::vector<int64_t> i_keys;
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();
if (finished_reading_dataset_) {
break;
}
if (shuffle_files_) {
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
}
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
}
return Status::OK();
}
Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0;
int64_t pre_count = 0;
int64_t start_offset = 0;
int64_t end_offset = 0;
bool finish = false;
while (!finish) {
std::vector<std::pair<std::string, int64_t>> file_index;
if (!i_keys.empty()) {
for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
{
if (!load_io_block_queue_) {
break;
}
}
file_index.emplace_back(std::pair<std::string, int64_t>((*filename_index_)[*it], *it));
}
} else {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
{
if (!load_io_block_queue_) {
break;
}
}
file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
}
}
for (auto file_info : file_index) {
if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
auto ioBlock =
std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
queue_index = (queue_index + 1) % num_workers_;
}
pre_count += filename_numrows_[file_info.first];
}
if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
finish = false;
} else {
finish = true;
}
}
RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
return Status::OK();
}
void CsvOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
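// Decides whether file_name overlaps the row range [device_id_ * num_rows_per_shard_,
// (device_id_ + 1) * num_rows_per_shard_) owned by this shard; if it does, start_offset/end_offset are
// set to the portion of the file that belongs to this shard.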
bool CsvOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count) {
*start_offset = 0;
*end_offset = 0;
bool push = false;
int64_t start_index = device_id_ * num_rows_per_shard_;
if (device_id_ + 1 < 0) {
MS_LOG(ERROR) << "Device id is invalid";
return false;
}
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
*start_offset = start_index - pre_count;
push = true;
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
if (pre_count >= start_index && pre_count < end_index) {
*start_offset = 0;
push = true;
if (pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
return push;
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status CsvOp::PostEndOfEpoch(int32_t queue_index) {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
}
return Status::OK();
}
Status CsvOp::CalculateNumRowsPerShard() {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
int64_t count = CountTotalRows(it.value());
filename_numrows_[it.value()] = count;
all_num_rows_ += count;
}
if (all_num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API CSVDataset. Please check the file path or dataset API "
"validation first.");
}
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
return Status::OK();
}
int64_t CsvOp::CountTotalRows(const std::string &file) {
CsvParser csv_parser(0, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_);
std::ifstream ifs;
ifs.open(file, std::ifstream::in);
if (column_name_list_.empty()) {
std::string tmp;
getline(ifs, tmp);
}
csv_parser.Reset();
while (ifs.good()) {
char chr = ifs.get();
if (csv_parser.countRows(chr) != 0) {
break;
}
}
return csv_parser.total_rows_;
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status CsvOp::PostEndOfData() {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
}
return Status::OK();
}
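// Builds a temporary CsvOp just to reuse CountTotalRows() on every file; when csv_header is true, the
// first line of each file is treated as a header and excluded from the row count.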
Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) {
std::shared_ptr<CsvOp> op;
*count = 0;
if (csv_header) {
RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).Build(&op));
} else {
RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).SetColumName({""}).Build(&op));
}
for (auto file : files) {
*count += op->CountTotalRows(file);
}
return Status::OK();
}
std::vector<std::string> CsvOp::split(const std::string &s, char delim) {
std::vector<std::string> res;
std::stringstream ss(s);
std::string item;
while (getline(ss, item, delim)) {
res.push_back(item);
}
return res;
}
Status CsvOp::ComputeColMap() {
// Set the column name mapping (base class field)
if (column_name_id_map_.empty()) {
if (column_name_list_.empty()) {
std::string line;
std::ifstream handle(csv_files_list_[0]);
getline(handle, line);
std::vector<std::string> col_names = split(line, field_delim_);
for (int32_t i = 0; i < col_names.size(); i++) {
column_name_id_map_[col_names[i]] = i;
}
} else {
for (int32_t i = 0; i < column_name_list_.size(); i++) {
column_name_id_map_[column_name_list_[i]] = i;
}
}
} else {
MS_LOG(WARNING) << "Column name map is already set!";
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <utility>
#include <limits>
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
namespace mindspore {
namespace dataset {
const size_t CSV_BUFFER_SIZE = 4096;
using StringIndex = AutoIndexObj<std::string>;
class JaggedConnector;
class CsvOp : public ParallelOp {
public:
enum RecordType : uint8_t { INT = 0, FLOAT, STRING };
struct BaseRecord {
public:
BaseRecord() = default;
explicit BaseRecord(RecordType t) : type(t) {}
virtual ~BaseRecord() {}
RecordType type;
};
template <typename T>
class Record : public BaseRecord {
public:
Record() = default;
Record(RecordType t, T v) : BaseRecord(t), value(v) {}
~Record() {}
T value;
};
// CsvParser is a class that parses CSV files.
// We design a state machine to implement the CSV syntactic analysis. It contains two state diagrams, 'sd' and 'sdl'.
// The 'sd' is used for parsing the CSV syntax; it is complete but complicated.
// The 'sdl' is used for counting the record rows; it is concise and runs fast.
struct CsvParser {
public:
CsvParser() = delete;
CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer, char field_delim,
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default)
: worker_id_(worker_id),
buffer_connector_(connector),
csv_rows_per_buffer_(rows_per_buffer),
csv_field_delim_(field_delim),
column_default_(column_default),
cur_state_(START_OF_FILE),
pos_(0),
cur_row_(0),
cur_col_(0),
total_rows_(0),
start_offset_(0),
end_offset_(std::numeric_limits<int64_t>::max()) {
cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
initCsvParser();
}
~CsvParser() = default;
void Reset() {
cur_state_ = START_OF_FILE;
pos_ = 0;
cur_row_ = 0;
cur_col_ = 0;
}
void setStartOffset(int64_t start_offset) { start_offset_ = start_offset; }
void setEndOffset(int64_t end_offset) { end_offset_ = end_offset; }
int processMessage(char c) {
Message m = getMessage(c);
StateDiagram::iterator it = sd.find({cur_state_, m});
if (it == sd.end()) {
return -1;
}
cur_state_ = it->second.first;
return it->second.second(*this, c);
}
int countRows(char c);
Status initCsvParser();
enum State : uint8_t {
START_OF_FILE = 0,
UNQUOTE,
DELIM,
QUOTE,
SECOND_QUOTE,
END_OF_LINE,
END_OF_FILE,
EXCEPTION
};
enum Message : uint8_t {
MS_NORMAL = 0,
MS_DELIM,
MS_QUOTE,
MS_END_OF_LINE,
MS_END_OF_FILE,
};
typedef std::pair<State, Message> StateMessagePair;
typedef std::pair<State, std::function<int(CsvParser &, char)>> StateActionPair;
typedef std::map<StateMessagePair, StateActionPair> StateDiagram;
Message getMessage(char c) {
if (c == csv_field_delim_) {
return Message::MS_DELIM;
} else if (c == '"') {
return Message::MS_QUOTE;
} else if (c == '\r' || c == '\n') {
return Message::MS_END_OF_LINE;
} else if (c == std::char_traits<char>::eof()) {
return Message::MS_END_OF_FILE;
} else {
return Message::MS_NORMAL;
}
}
int null_func(char c) { return 0; }
int put_char(char c) {
if (pos_ >= str_buf_.size()) {
str_buf_.resize(str_buf_.size() * 2);
}
str_buf_[pos_] = c;
pos_++;
return 0;
}
int put_record(char c);
int put_row(char c);
int end_file(char c);
int add_row(char c) {
total_rows_++;
return 0;
}
int catch_exception(char c) {
MS_LOG(ERROR) << "Invalid syntax!";
return -1;
}
int32_t worker_id_;
std::shared_ptr<JaggedConnector> buffer_connector_;
int64_t csv_rows_per_buffer_;
const char csv_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_;
State cur_state_;
size_t pos_;
int cur_row_;
int cur_col_;
int64_t total_rows_;
int64_t start_offset_;
int64_t end_offset_;
StateDiagram sd;
StateDiagram sdl;
std::vector<char> str_buf_;
std::unique_ptr<TensorQTable> tensor_table_;
std::unique_ptr<DataBuffer> cur_buffer_;
};
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder();
// Default destructor
~Builder() = default;
// Checks if the inputs of the builder are valid.
// @return Status - the error code returned.
Status ValidateInputs() const;
// Create the final object.
// @param op - dataset op.
// @return - the error code returned.
Status Build(std::shared_ptr<CsvOp> *op);
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
builder_num_workers_ = num_workers;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
builder_rows_per_buffer_ = rows_per_buffer;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumDevices(int64_t num_dev) {
builder_num_devices_ = num_dev;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetDeviceId(int64_t dev_id) {
builder_device_id_ = dev_id;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetCsvFilesList(const std::vector<std::string> &files_list) {
builder_csv_files_list_ = files_list;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetShuffleFiles(bool shuffle_files) {
builder_shuffle_files_ = shuffle_files;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetFieldDelim(char field_delim) {
builder_field_delim_ = field_delim;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetColumDefault(std::vector<std::shared_ptr<CsvOp::BaseRecord>> record_list) {
builder_column_default_list_ = record_list;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetColumName(std::vector<std::string> col_name_list) {
builder_column_name_list_ = col_name_list;
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int64_t builder_rows_per_buffer_;
int64_t builder_num_samples_;
int32_t builder_worker_connector_size_;
std::vector<std::string> builder_csv_files_list_;
bool builder_shuffle_files_;
char builder_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
std::vector<std::string> builder_column_name_list_;
};
// Constructor of CsvOp
CsvOp() = delete;
CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~CsvOp() = default;
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
// Instantiates the internal queues and connectors
// @return Status - the error code returned
Status Init();
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Get total rows in files.
// @param files - all csv files.
// @param csv_header - a bool that indicates whether the csv file includes a header line
// @param count - number of rows.
// @return Status - the error code returned.
static Status CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count);
// File names getter
// @return Vector of the input file names
std::vector<std::string> FileNames() { return csv_files_list_; }
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Parses a single row and puts the data into a tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.
// @param row - the id of the row filled in the tensor table.
// @return Status - the error code returned.
Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);
// Reads a csv file and loads the data into multiple buffers.
// @param file - the file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Fill the IOBlockQueue.
// @param i_keys - keys of the files to fill into the IOBlockQueue
// @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_offset - Output: start row offset of this shard within the file.
// @param end_offset - Output: end row offset of this shard within the file.
// @param pre_count - Total rows of previous files.
// @return bool - whether the file needs to be pushed onto the block queue.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status CalculateNumRowsPerShard();
// Count number of rows in each file.
// @param file - csv file name.
// @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Private function for computing the assignment of the column name map.
// @return - Status
Status ComputeColMap() override;
// Split string based on a character delimiter
// @return - a vector of the split substrings
std::vector<std::string> split(const std::string &s, char delim);
int32_t device_id_;
bool shuffle_files_;
bool finished_reading_dataset_;
int32_t num_devices_;
int64_t rows_per_buffer_;
bool load_io_block_queue_;
int64_t num_rows_per_shard_;
int64_t all_num_rows_;
int64_t num_samples_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> csv_files_list_;
WaitPost io_block_queue_wait_post_;
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
char field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_;
std::vector<std::string> column_name_list_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
@@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
from .core import config
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
TextFileDataset, CLUEDataset, CSVDataset, Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
from .engine.cache_client import DatasetCache
@@ -31,5 +31,5 @@ from .engine.graphdata import GraphData
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset",
"CocoDataset", "TextFileDataset", "CLUEDataset", "CSVDataset", "Schema", "DistributedSampler", "PKSampler",
"RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
@@ -29,7 +29,7 @@ from .samplers import *
from ..core import config
__all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", "CSVDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
"VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler",
"PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]
@@ -33,7 +33,7 @@ import copy
import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
MindRecordOp, TextFileOp, ClueOp, CsvOp, VOCOp, CocoOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger
@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
@@ -1012,7 +1012,7 @@ class Dataset:
if isinstance(sampler, samplers.DistributedSampler):
dev_id = sampler.shard_id
return "", dev_id
if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset, CSVDataset)):
if output_dataset.shard_id is not None:
dev_id = output_dataset.shard_id
return "", dev_id
@@ -4652,8 +4652,8 @@ class CLUEDataset(SourceDataset):
}
Args:
dataset_files (str or a list of strings): String or list of files to be read or glob strings to search for
a pattern of files. The list will be sorted in a lexicographical order.
task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'.
(default=AFQMC).
usage (str, optional): Need train, test or eval data (default="train").
@@ -4860,6 +4860,108 @@ class CLUEDataset(SourceDataset):
return False
class CSVDataset(SourceDataset):
"""
A source dataset that reads and parses CSV datasets.
Args:
dataset_files (str or a list of strings): String or list of files to be read or glob strings to search
for a pattern of files. The list will be sorted in a lexicographical order.
field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=',').
column_defaults (list, optional): List of default values for the CSV fields (default=None). Each item
in the list is a default value whose type must be float, int or str. If this argument is not
provided, all columns are treated as strings.
column_names (list of string, optional): List of column names of the dataset (default=None). If this
is not provided, the column names are inferred from the first row of the CSV file.
num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (bool or Shuffle, optional): Perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
if shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL.
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and the samples.
- Shuffle.FILES: Shuffle the files only.
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
Examples:
>>> import mindspore.dataset as ds
>>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or more CSV files
>>> dataset = ds.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4'])
"""
@check_csvdataset
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_files = self._find_files(dataset_files)
self.dataset_files.sort()
self.field_delim = field_delim
self.column_defaults = column_defaults
self.column_names = column_names
self.num_samples = num_samples
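# Normalize the user-facing `shuffle` argument into the two flags reported by get_args():
# `shuffle_files` shuffles the order of the input files, while `shuffle_global` additionally
# requests a row-level shuffle across the whole dataset.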
if not isinstance(shuffle, (bool, Shuffle)):
raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
if not isinstance(shuffle, Shuffle):
if shuffle:
self.shuffle_level = Shuffle.GLOBAL
self.shuffle_files = True
else:
self.shuffle_level = None
self.shuffle_files = False
else:
self.shuffle_level = shuffle
self.shuffle_files = True
self.num_shards = num_shards
self.shard_id = shard_id
def get_args(self):
args = super().get_args()
args["dataset_files"] = self.dataset_files
args['field_delim'] = self.field_delim
args['column_defaults'] = self.column_defaults
args['column_names'] = self.column_names
args["num_samples"] = self.num_samples
if self.shuffle_files is not None:
args["shuffle_files"] = self.shuffle_files
args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
if self._dataset_size is None:
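# The second argument tells the underlying CsvOp whether the first row of each file is a
# header row (i.e. column_names were not supplied explicitly).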
num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
num_rows = get_num_rows(num_rows, self.num_shards)
if self.num_samples is None:
return num_rows
return min(self.num_samples, num_rows)
return self._dataset_size
def is_shuffled(self):
return self.shuffle_files
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return False
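For reference, a minimal usage sketch of the new loader, assembled from the arguments documented above (the file path and column layout below are illustrative, not part of this change):

import mindspore.dataset as ds

# Illustrative path and schema; integer column_defaults make all four columns numeric.
data = ds.CSVDataset(
    dataset_files=["/path/to/train.csv"],
    field_delim=',',
    column_defaults=[0, 0, 0, 0],
    column_names=['col1', 'col2', 'col3', 'col4'],
    shuffle=False,
    num_shards=2, shard_id=0)  # read one of two shards
for row in data.create_dict_iterator():
    print(row['col1'], row['col4'])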
class TextFileDataset(SourceDataset):
"""
A source dataset that reads and parses datasets stored on disk in text format.
......
...@@ -185,6 +185,8 @@ class Iterator:
op_type = OpName.SENTENCEPIECEVOCAB
elif isinstance(dataset, de.CLUEDataset):
op_type = OpName.CLUE
elif isinstance(dataset, de.CSVDataset):
op_type = OpName.CSV
else:
raise ValueError("Unsupported DatasetOp")
......
...@@ -787,6 +787,49 @@ def check_cluedataset(method):
return new_method
def check_csvdataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(CSVDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
# check dataset_files; required argument
dataset_files = param_dict.get('dataset_files')
type_check(dataset_files, (str, list), "dataset files")
# check field_delim
field_delim = param_dict.get('field_delim')
type_check(field_delim, (str,), 'field delim')
if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
raise ValueError("field_delim is not legal.")
# check column_defaults
column_defaults = param_dict.get('column_defaults')
if column_defaults is not None:
if not isinstance(column_defaults, list):
raise TypeError("column_defaults should be type of list.")
for item in column_defaults:
if not isinstance(item, (str, int, float)):
raise TypeError("column type is not legal.")
# check column_names: must be list of string.
column_names = param_dict.get("column_names")
if column_names is not None:
all_string = all(isinstance(item, str) for item in column_names)
if not all_string:
raise TypeError("column_names should be a list of str.")
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
return method(self, *args, **kwargs)
return new_method
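To illustrate what the checker above rejects, a small sketch (assuming pytest; the path is never opened because validation runs before the dataset is constructed):

import pytest
import mindspore.dataset as ds

# A delimiter of '"', '\r', '\n' or more than one character is rejected with ValueError.
with pytest.raises(ValueError):
    ds.CSVDataset("/path/to/any.csv", field_delim="||")

# column_names must contain only strings, otherwise a TypeError is raised.
with pytest.raises(TypeError):
    ds.CSVDataset("/path/to/any.csv", column_names=['col1', 2])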
def check_textfiledataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
......
...@@ -77,6 +77,7 @@ SET(DE_UT_SRCS
celeba_op_test.cc
take_op_test.cc
clue_op_test.cc
csv_op_test.cc
text_file_op_test.cc
filter_op_test.cc
concat_op_test.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <vector>
#include "minddata/dataset/core/client.h"
#include "common/common.h"
#include "common/utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/util/status.h"
namespace common = mindspore::common;
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestCSVOp : public UT::DatasetOpTesting {
};
TEST_F(MindDataTestCSVOp, TestCSVBasic) {
// Start with an empty execution tree
auto tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testCSV/1.csv";
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
std::shared_ptr<CsvOp> op;
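// Configure the op: read 1.csv with ',' as the field delimiter, four INT columns defaulting to 0, no file shuffling.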
CsvOp::Builder builder;
builder.SetCsvFilesList({dataset_path})
.SetRowsPerBuffer(16)
.SetNumWorkers(16)
.SetShuffleFiles(false)
.SetOpConnectorSize(2)
.SetFieldDelim(',')
.SetColumDefault(column_default_list)
.SetColumName({"col1", "col2", "col3", "col4"});
Status rc = builder.Build(&op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssociateNode(op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssignRoot(op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration.";
rc = tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
// Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 3);
}
TEST_F(MindDataTestCSVOp, TestTotalRows) {
std::string csv_file1 = datasets_root_path_ + "/testCSV/1.csv";
std::string csv_file2 = datasets_root_path_ + "/testCSV/size.csv";
std::vector<std::string> files;
files.push_back(csv_file1);
int64_t total_rows = 0;
CsvOp::CountAllFileRows(files, false, &total_rows);
ASSERT_EQ(total_rows, 3);
files.clear();
files.push_back(csv_file2);
CsvOp::CountAllFileRows(files, false, &total_rows);
ASSERT_EQ(total_rows, 5);
files.clear();
files.push_back(csv_file1);
files.push_back(csv_file2);
CsvOp::CountAllFileRows(files, false, &total_rows);
ASSERT_EQ(total_rows, 8);
files.clear();
}
,"222",3,"4"""
"5",6,,"8"
9,10,"1""1",12
,,"",
,,,
a,b,c,""
a,b,c,d
大家,早上好,中午好,下午好,晚上好
col1,col2,col3,col4
a,b,c,d
\ No newline at end of file
1,2,3,4
"a","b","c
","d
e"
5,6,7,8
9,10,11,12
a,"b
",c,"d
e"
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import numpy as np
import pytest
DATA_FILE = '../data/dataset/testCSV/1.csv'
def test_csv_dataset_basic():
"""
Test CSV with repeat, skip and so on
"""
TRAIN_FILE = '../data/dataset/testCSV/1.csv'
buffer = []
data = ds.CSVDataset(
TRAIN_FILE,
column_defaults=["0", 0, 0.0, "0"],
column_names=['1', '2', '3', '4'],
shuffle=False)
data = data.repeat(2)
data = data.skip(2)
for d in data.create_dict_iterator():
buffer.append(d)
assert len(buffer) == 4
def test_csv_dataset_one_file():
data = ds.CSVDataset(
DATA_FILE,
column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.append(d)
assert len(buffer) == 3
def test_csv_dataset_all_file():
APPEND_FILE = '../data/dataset/testCSV/2.csv'
data = ds.CSVDataset(
[DATA_FILE, APPEND_FILE],
column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.append(d)
assert len(buffer) == 10
def test_csv_dataset_num_samples():
data = ds.CSVDataset(
DATA_FILE,
column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False, num_samples=2)
count = 0
for _ in data.create_dict_iterator():
count += 1
assert count == 2
def test_csv_dataset_distribution():
TEST_FILE = '../data/dataset/testCSV/1.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False, num_shards=2, shard_id=0)
count = 0
for _ in data.create_dict_iterator():
count += 1
assert count == 2
def test_csv_dataset_quoted():
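# Fields wrapped in double quotes should come back with the surrounding quotes stripped.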
TEST_FILE = '../data/dataset/testCSV/quoted.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", "", "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item().decode("utf8"),
d['col2'].item().decode("utf8"),
d['col3'].item().decode("utf8"),
d['col4'].item().decode("utf8")])
assert buffer == ['a', 'b', 'c', 'd']
def test_csv_dataset_separated():
TEST_FILE = '../data/dataset/testCSV/separated.csv'
data = ds.CSVDataset(
TEST_FILE,
field_delim='|',
column_defaults=["", "", "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item().decode("utf8"),
d['col2'].item().decode("utf8"),
d['col3'].item().decode("utf8"),
d['col4'].item().decode("utf8")])
assert buffer == ['a', 'b', 'c', 'd']
def test_csv_dataset_embedded():
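# Delimiters, double quotes and newlines embedded inside quoted fields must survive into the parsed values.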
TEST_FILE = '../data/dataset/testCSV/embedded.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", "", "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item().decode("utf8"),
d['col2'].item().decode("utf8"),
d['col3'].item().decode("utf8"),
d['col4'].item().decode("utf8")])
assert buffer == ['a,b', 'c"d', 'e\nf', ' g ']
def test_csv_dataset_chinese():
TEST_FILE = '../data/dataset/testCSV/chinese.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", "", "", "", ""],
column_names=['col1', 'col2', 'col3', 'col4', 'col5'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item().decode("utf8"),
d['col2'].item().decode("utf8"),
d['col3'].item().decode("utf8"),
d['col4'].item().decode("utf8"),
d['col5'].item().decode("utf8")])
assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好']
def test_csv_dataset_header():
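# No column_names are given, so the names are taken from the header row of the file.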
TEST_FILE = '../data/dataset/testCSV/header.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", "", "", ""],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item().decode("utf8"),
d['col2'].item().decode("utf8"),
d['col3'].item().decode("utf8"),
d['col4'].item().decode("utf8")])
assert buffer == ['a', 'b', 'c', 'd']
def test_csv_dataset_number():
TEST_FILE = '../data/dataset/testCSV/number.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=[0.0, 0.0, 0, 0.0],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item(),
d['col2'].item(),
d['col3'].item(),
d['col4'].item()])
assert np.allclose(buffer, [3.0, 0.3, 4, 55.5])
def test_csv_dataset_size():
TEST_FILE = '../data/dataset/testCSV/size.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=[0.0, 0.0, 0, 0.0],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
assert data.get_dataset_size() == 5
def test_csv_dataset_exception():
TEST_FILE = '../data/dataset/testCSV/exception.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", "", "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
with pytest.raises(Exception) as err:
for _ in data.create_dict_iterator():
pass
assert "Failed to parse CSV file" in str(err.value)
def test_csv_dataset_type_error():
TEST_FILE = '../data/dataset/testCSV/exception.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", 0, "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
with pytest.raises(Exception) as err:
for _ in data.create_dict_iterator():
pass
assert "invalid argument of stoi" in str(err.value)
if __name__ == "__main__":
test_csv_dataset_basic()
test_csv_dataset_one_file()
test_csv_dataset_all_file()
test_csv_dataset_num_samples()
test_csv_dataset_distribution()
test_csv_dataset_quoted()
test_csv_dataset_separated()
test_csv_dataset_embedded()
test_csv_dataset_chinese()
test_csv_dataset_header()
test_csv_dataset_number()
test_csv_dataset_size()
test_csv_dataset_exception()
test_csv_dataset_type_error()