diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc index 26b823200d6fbd913dcded412da77bda7ebb3517..44eda2dfc1bc16663fe0d9262ec6f4e4c016ac84 100644 --- a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc @@ -31,6 +31,7 @@ #include "minddata/dataset/engine/datasetops/source/celeba_op.h" #include "minddata/dataset/engine/datasetops/source/cifar_op.h" #include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/engine/datasetops/source/csv_op.h" #include "minddata/dataset/engine/datasetops/source/coco_op.h" #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "minddata/dataset/engine/datasetops/source/manifest_op.h" @@ -88,6 +89,7 @@ static std::unordered_map g_parse_op_func_ = { {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, {kClue, &DEPipeline::ParseClueOp}, {kEpochCtrl, &DEPipeline::ParseEpochCtrlOp}, + {kCsv, &DEPipeline::ParseCsvOp}, {kSentencePieceVocab, &DEPipeline::ParseBuildSentencePieceVocabOp}}; DEPipeline::DEPipeline() : iterator_(nullptr) { @@ -1838,6 +1840,86 @@ Status DEPipeline::AddCacheOp(std::shared_ptr cache_client, int num return Status::OK(); } +Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector files_list; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetCsvFilesList(files_list); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + + // Optional arguments + bool shuffle_required = false; + int64_t num_devices = 0; + std::vector col_names; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "num_samples") { + (void)builder->SetNumSamples(ToInt(value)); + } else if (key == "num_shards") { + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "field_delim") { + (void)builder->SetFieldDelim(ToString(value)[0]); + } else if (key == "column_defaults") { + py::list py_object_list = py::reinterpret_borrow(value); + std::vector> column_default_list; + for (auto l : py_object_list) { + std::string type_s = (std::string)py::str(l.get_type().attr("__name__")); + if (type_s == "int") { + column_default_list.push_back(std::make_shared>(CsvOp::INT, ToInt(l))); + } else if (type_s == "float") { + column_default_list.push_back(std::make_shared>(CsvOp::FLOAT, ToFloat(l))); + } else if (type_s == "str") { + column_default_list.push_back(std::make_shared>(CsvOp::STRING, ToString(l))); + } else { + RETURN_STATUS_UNEXPECTED("Record type is not allowed"); + } + } + (void)builder->SetColumDefault(column_default_list); + } else if (key == "column_names") { + col_names = ToStringVector(value); + (void)builder->SetColumName(col_names); + } + } + } + + std::shared_ptr csv_op; + RETURN_IF_NOT_OK(builder->Build(&csv_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op)); + *top = csv_op; + + if (shuffle_required) { + std::shared_ptr shuffle_op = nullptr; + int64_t 
shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset and then compute the shuffle size + RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(files_list, col_names.empty(), &num_rows)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, csv_op, &shuffle_op)); + *top = shuffle_op; + *bottom = csv_op; + } + + return Status::OK(); +} + // Helper function to inject a shuffle operator over top of the current operation being built. Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, std::shared_ptr *shuffle_op) { diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h index 3f1707f67c0bd5570463c58ba0c92e8650381a09..80d524982ab6d07c6c2e829dbc762ce16fd88a92 100644 --- a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h @@ -73,6 +73,7 @@ enum OpName { kClue, kEpochCtrl, kSentencePieceVocab, + kCsv }; // The C++ binder class that we expose to the python script. @@ -201,6 +202,8 @@ class DEPipeline { Status ParseClueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + Status ParseCsvOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + private: // Execution tree that links the dataset operators. std::shared_ptr tree_; diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc index 3edb2771754009bc901e5c90ce1153b171b2d2d9..2768ab192c2a15e169d5d77db47ab292ede3c538 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc @@ -19,6 +19,7 @@ #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/datasetops/source/cifar_op.h" #include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/engine/datasetops/source/csv_op.h" #include "minddata/dataset/engine/datasetops/source/coco_op.h" #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "minddata/dataset/engine/datasetops/source/io_block.h" @@ -277,6 +278,17 @@ void bindDatasetOps(py::module *m) { return count; }); + (void)py::class_>(*m, "CsvOp") + .def_static("get_num_rows", [](const py::list &files, bool csv_header) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + file.is_none() ? 
(void)filenames.emplace_back("") : filenames.push_back(py::str(file)); + } + THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count)); + return count; + }); + (void)py::class_>(*m, "VOCOp") .def_static("get_num_rows", [](const std::string &dir, const std::string &task_type, const std::string &task_mode, @@ -1039,8 +1051,9 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab) .value("CELEBA", OpName::kCelebA) .value("TEXTFILE", OpName::kTextFile) - .value("CLUE", OpName::kClue) - .value("EPOCHCTRL", OpName::kEpochCtrl); + .value("EPOCHCTRL", OpName::kEpochCtrl) + .value("CSV", OpName::kCsv) + .value("CLUE", OpName::kClue); (void)py::enum_(m, "JiebaMode", py::arithmetic()) .value("DE_JIEBA_MIX", JiebaMode::kMix) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt index 389e3f5af6da22d77aeadc92816b4743bf4caec9..868c6fdb8914a9f848da639f08ccbaa5d41224c8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt @@ -12,6 +12,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES celeba_op.cc text_file_op.cc clue_op.cc + csv_op.cc ) set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES @@ -29,4 +30,4 @@ if (ENABLE_PYTHON) ) endif() -add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES}) \ No newline at end of file +add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES}) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8bc4d4fcdc056bb3d700f45ce7dfd51197410bed --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -0,0 +1,757 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/csv_op.h" + +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/jagged_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +CsvOp::Builder::Builder() + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_num_workers_ = config_manager->num_parallel_workers(); + builder_op_connector_size_ = config_manager->op_connector_size(); + builder_rows_per_buffer_ = config_manager->rows_per_buffer(); + builder_worker_connector_size_ = config_manager->worker_connector_size(); +} + +Status CsvOp::Builder::ValidateInputs() const { + std::string err; + err += builder_num_workers_ <= 0 ? 
"Number of parallel workers should be greater than 0\n" : ""; + err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : ""; + return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err); +} + +Status CsvOp::Builder::Build(std::shared_ptr *op) { + RETURN_IF_NOT_OK(ValidateInputs()); + + // Throttle the number of workers if we have more workers than files! + if (static_cast(builder_num_workers_) > builder_csv_files_list_.size()) { + builder_num_workers_ = builder_csv_files_list_.size(); + MS_LOG(WARNING) << "CsvOp operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + std::shared_ptr csv_op = std::make_shared( + builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_, + builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, + builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_); + RETURN_IF_NOT_OK(csv_op->Init()); + *op = std::move(csv_op); + + return Status::OK(); +} + +CsvOp::CsvOp(const std::vector &csv_files_list, char field_delim, + const std::vector> &column_default, + const std::vector &column_name, int32_t num_workers, int64_t rows_per_buffer, + int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, + int32_t num_device, int32_t device_id) + : ParallelOp(num_workers, op_connector_size), + csv_files_list_(std::move(csv_files_list)), + field_delim_(field_delim), + column_default_list_(column_default), + column_name_list_(column_name), + rows_per_buffer_(rows_per_buffer), + num_rows_per_shard_(0), + all_num_rows_(0), + num_samples_(num_samples), + filename_index_(std::make_unique()), + load_jagged_connector_(true), + shuffle_files_(shuffle_files), + finished_reading_dataset_(false), + num_devices_(num_device), + device_id_(device_id), + load_io_block_queue_(true) { + worker_connector_size_ = worker_connector_size; +} + +Status CsvOp::Init() { + RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_)); + + int32_t safe_queue_size = static_cast(std::ceil(csv_files_list_.size() / num_workers_) + 1); + io_block_queues_.Init(num_workers_, safe_queue_size); + + RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + jagged_buffer_connector_ = std::make_shared(num_workers_, 1, worker_connector_size_); + + return Status::OK(); +} + +int CsvOp::CsvParser::put_record(char c) { + std::string s = std::string(str_buf_.begin(), str_buf_.begin() + pos_); + std::shared_ptr t; + switch (column_default_[cur_col_]->type) { + case CsvOp::INT: + Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)); + t->SetItemAt({0}, std::stoi(s)); + break; + case CsvOp::FLOAT: + Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32)); + t->SetItemAt({0}, std::stof(s)); + break; + case CsvOp::STRING: + Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar()); + break; + default: + Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar()); + break; + } + (*tensor_table_)[cur_row_][cur_col_] = std::move(t); + pos_ = 0; + cur_col_++; + return 0; +} + +int CsvOp::CsvParser::put_row(char c) { + if (total_rows_ < start_offset_) { + total_rows_++; + cur_col_ = 0; + return 0; + } + + if (total_rows_ >= end_offset_) { + return 0; + } + + put_record(c); + total_rows_++; + cur_row_++; + cur_col_ = 0; + + if 
(cur_row_ == csv_rows_per_buffer_) { + cur_buffer_->set_tensor_table(std::move(tensor_table_)); + buffer_connector_->Add(worker_id_, std::move(cur_buffer_)); + + cur_buffer_ = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + tensor_table_ = std::make_unique(); + cur_row_ = 0; + } + return 0; +} + +int CsvOp::CsvParser::end_file(char c) { + if (cur_col_ > 0) { + put_row(c); + } + if (cur_row_ > 0) { + cur_buffer_->set_tensor_table(std::move(tensor_table_)); + buffer_connector_->Add(worker_id_, std::move(cur_buffer_)); + } + return 0; +} + +int CsvOp::CsvParser::countRows(char c) { + Message m; + if (c == '"') { + m = Message::MS_QUOTE; + } else if (c == '\r' || c == '\n' || c == std::char_traits::eof()) { + m = Message::MS_END_OF_LINE; + } else { + m = Message::MS_NORMAL; + } + StateDiagram::iterator it = sdl.find({cur_state_, m}); + if (it == sd.end()) { + return -1; + } + cur_state_ = it->second.first; + return it->second.second(*this, c); +} + +Status CsvOp::CsvParser::initCsvParser() { + str_buf_.resize(CSV_BUFFER_SIZE); + + // State diagram for counting rows + sdl = {// START_OF_FILE + // ┌───────────┬───────────┬─────────────┐ + // │ abc │ " │ \n │ + // ├───────────┼───────────┼─────────────┤ + // │ UNQUOTE │ QUOTE │ END_OF_LINE │ + // ├───────────┼───────────┼─────────────┤ + // | null_func │ null_func │ null_func │ + // └───────────┴───────────┴─────────────┘ + {{State::START_OF_FILE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, + {{State::START_OF_FILE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, + {{State::START_OF_FILE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}, + + // UNQUOTE + // ┌───────────┬───────────┬─────────────┐ + // │ abc │ " │ \n │ + // ├───────────┼───────────┼─────────────┤ + // │ UNQUOTE │ QUOTE │ END_OF_LINE │ + // ├───────────┼───────────┼─────────────┤ + // | null_func │ null_func │ add_row │ + // └───────────┴───────────┴─────────────┘ + {{State::UNQUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, + {{State::UNQUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, + {{State::UNQUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::add_row}}, + + // QUOTE + // ┌───────────┬──────────────┬───────────┐ + // │ abc │ " │ \n │ + // ├───────────┼──────────────┼───────────┤ + // │ QUOTE │ SECOND_QUOTE │ QUOTE │ + // ├───────────┼──────────────┼───────────┤ + // | null_func │ null_func │ null_func │ + // └───────────┴──────────────┴───────────┘ + {{State::QUOTE, Message::MS_NORMAL}, {State::QUOTE, &CsvParser::null_func}}, + {{State::QUOTE, Message::MS_QUOTE}, {State::SECOND_QUOTE, &CsvParser::null_func}}, + {{State::QUOTE, Message::MS_END_OF_LINE}, {State::QUOTE, &CsvParser::null_func}}, + + // SECOND_QUOTE + // ┌───────────┬───────────┬─────────────┐ + // │ abc │ " │ \n │ + // ├───────────┼───────────┼─────────────┤ + // │ UNQUOTE │ QUOTE │ END_OF_LINE │ + // ├───────────┼───────────┼─────────────┤ + // | null_func │ null_func │ add_row │ + // └───────────┴───────────┴─────────────┘ + {{State::SECOND_QUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, + {{State::SECOND_QUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, + {{State::SECOND_QUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::add_row}}, + + // END_OF_LINE + // ┌───────────┬───────────┬─────────────┐ + // │ abc │ " │ \n │ + // ├───────────┼───────────┼─────────────┤ + // │ UNQUOTE │ QUOTE │ END_OF_LINE │ + // 
├───────────┼───────────┼─────────────┤ + // | null_func │ null_func │ null_func │ + // └───────────┴───────────┴─────────────┘ + {{State::END_OF_LINE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::null_func}}, + {{State::END_OF_LINE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::null_func}}, + {{State::END_OF_LINE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}}; + + // State diagram for CSV parser + sd = {// START_OF_FILE + // ┌───────────┬──────────┬──────────┬────────────────┬────────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├───────────┼──────────┼──────────┼────────────────┼────────────────┤ + // │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ + // ├───────────┼──────────┼──────────┼────────────────┼────────────────┤ + // | lambda │ lambda │ lambda │ null_func │ null_func │ + // └───────────┴──────────┴──────────┴────────────────┴────────────────┘ + {{State::START_OF_FILE, Message::MS_NORMAL}, + {State::UNQUOTE, + [this](CsvParser &, char c) -> int { + this->tensor_table_ = std::make_unique(); + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + this->str_buf_[0] = c; + this->pos_ = 1; + return 0; + }}}, + {{State::START_OF_FILE, Message::MS_DELIM}, + {State::DELIM, + [this](CsvParser &, char c) -> int { + this->tensor_table_ = std::make_unique(); + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + this->put_record(c); + return 0; + }}}, + {{State::START_OF_FILE, Message::MS_QUOTE}, + {State::QUOTE, + [this](CsvParser &, char c) -> int { + this->tensor_table_ = std::make_unique(); + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + this->pos_ = 0; + return 0; + }}}, + {{State::START_OF_FILE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}, + {{State::START_OF_FILE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::null_func}}, + + // UNQUOTE + // ┌───────────┬────────────┬───────────┬─────────────┬────────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ + // │ UNQUOTE │ DELIM │ EXCEPTION │ END_OF_LINE │ END_OF_FILE │ + // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ + // | put_char │ put_record │ exception │ put_row │ end_file │ + // └───────────┴────────────┴───────────┴─────────────┴────────────────┘ + {{State::UNQUOTE, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::put_char}}, + {{State::UNQUOTE, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}}, + {{State::UNQUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}}, + {{State::UNQUOTE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}, + // UNQUOTE-Exception + {{State::UNQUOTE, Message::MS_QUOTE}, {State::EXCEPTION, &CsvParser::catch_exception}}, + + // DELIM + // ┌───────────┬────────────┬───────────┬─────────────┬────────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ + // │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ + // ├───────────┼────────────┼───────────┼─────────────┼────────────────┤ + // | put_char │ put_record │ lambda │ put_row │ end_file │ + // └───────────┴────────────┴───────────┴─────────────┴────────────────┘ + {{State::DELIM, Message::MS_NORMAL}, {State::UNQUOTE, &CsvParser::put_char}}, + {{State::DELIM, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}}, + {{State::DELIM, Message::MS_QUOTE}, + {State::QUOTE, + [this](CsvParser 
&, char c) -> int { + this->pos_ = 0; + return 0; + }}}, + {{State::DELIM, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}}, + {{State::DELIM, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}, + + // QUOTE + // ┌───────────┬──────────┬──────────────┬──────────┬────────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├───────────┼──────────┼──────────────┼──────────┼────────────────┤ + // │ QUOTE │ QUOTE │ SECOND_QUOTE │ QUOTE │ EXCEPTION │ + // ├───────────┼──────────┼──────────────┼──────────┼────────────────┤ + // | put_char │ put_char │ null_func │ put_char │ exception │ + // └───────────┴──────────┴──────────────┴──────────┴────────────────┘ + {{State::QUOTE, Message::MS_NORMAL}, {State::QUOTE, &CsvParser::put_char}}, + {{State::QUOTE, Message::MS_DELIM}, {State::QUOTE, &CsvParser::put_char}}, + {{State::QUOTE, Message::MS_QUOTE}, {State::SECOND_QUOTE, &CsvParser::null_func}}, + {{State::QUOTE, Message::MS_END_OF_LINE}, {State::QUOTE, &CsvParser::put_char}}, + // QUOTE-Exception + {{State::QUOTE, Message::MS_END_OF_FILE}, {State::EXCEPTION, &CsvParser::catch_exception}}, + + // SECOND_QUOTE + // ┌───────────┬────────────┬──────────┬─────────────┬────────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├───────────┼────────────┼──────────┼─────────────┼────────────────┤ + // │ EXCEPTION │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ + // ├───────────┼────────────┼──────────┼─────────────┼────────────────┤ + // | exception │ put_record │ put_char │ put_row │ end_file │ + // └───────────┴────────────┴──────────┴─────────────┴────────────────┘ + {{State::SECOND_QUOTE, Message::MS_QUOTE}, {State::QUOTE, &CsvParser::put_char}}, + {{State::SECOND_QUOTE, Message::MS_DELIM}, {State::DELIM, &CsvParser::put_record}}, + {{State::SECOND_QUOTE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::put_row}}, + {{State::SECOND_QUOTE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}, + // SECOND_QUOTE-Exception + {{State::SECOND_QUOTE, Message::MS_NORMAL}, {State::EXCEPTION, &CsvParser::catch_exception}}, + + // END_OF_LINE + // ┌─────────┬────────┬────────┬─────────────┬─────────────┐ + // │ abc │ , │ " │ \n │ EOF │ + // ├─────────┼────────┼────────┼─────────────┼─────────────┤ + // │ UNQUOTE │ DELIM │ QUOTE │ END_OF_LINE │ END_OF_FILE │ + // ├─────────┼────────┼────────┼─────────────┼─────────────┤ + // | lambda │ lambda │ lambda │ null_func │ end_file │ + // └─────────┴────────┴────────┴─────────────┴─────────────┘ + {{State::END_OF_LINE, Message::MS_NORMAL}, + {State::UNQUOTE, + [this](CsvParser &, char c) -> int { + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + this->str_buf_[0] = c; + this->pos_ = 1; + return 0; + }}}, + {{State::END_OF_LINE, Message::MS_DELIM}, + {State::DELIM, + [this](CsvParser &, char c) -> int { + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + this->put_record(c); + return 0; + }}}, + {{State::END_OF_LINE, Message::MS_QUOTE}, + {State::QUOTE, + [this](CsvParser &, char c) -> int { + this->tensor_table_->push_back(TensorRow(column_default_.size(), nullptr)); + return 0; + }}}, + {{State::END_OF_LINE, Message::MS_END_OF_LINE}, {State::END_OF_LINE, &CsvParser::null_func}}, + {{State::END_OF_LINE, Message::MS_END_OF_FILE}, {State::END_OF_FILE, &CsvParser::end_file}}}; + return Status::OK(); +} + +Status CsvOp::Reset() { + load_jagged_connector_ = true; + load_io_block_queue_ = true; + + RETURN_IF_NOT_OK(ParallelOp::Reset()); + NotifyToFillIOBlockQueue(); 
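+  // (Descriptive note: the notification above wakes WaitToFillIOBlockQueue() so that fresh IO blocks
+  // are queued for the next epoch.)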
+ return Status::OK(); +} + +Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id) { + CsvParser csv_parser(worker_id, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_); + csv_parser.setStartOffset(start_offset); + csv_parser.setEndOffset(end_offset); + std::ifstream ifs; + ifs.open(file, std::ifstream::in); + if (column_name_list_.empty()) { + std::string tmp; + getline(ifs, tmp); + } + csv_parser.Reset(); + try { + while (ifs.good()) { + char chr = ifs.get(); + if (csv_parser.processMessage(chr) != 0) { + RETURN_STATUS_UNEXPECTED("Failed to parse CSV file " + file + ":" + std::to_string(csv_parser.total_rows_)); + } + } + } catch (std::invalid_argument &ia) { + std::string err_row = std::to_string(csv_parser.total_rows_); + RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", invalid argument of " + std::string(ia.what())); + } catch (std::out_of_range &oor) { + std::string err_row = std::to_string(csv_parser.total_rows_); + RETURN_STATUS_UNEXPECTED(file + ":" + err_row + ", out of Range error: " + std::string(oor.what())); + } + return Status::OK(); +} + +Status CsvOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling IoBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&CsvOp::WaitToFillIOBlockQueue, this))); + + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CsvOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); + NotifyToFillIOBlockQueue(); + + while (!finished_reading_dataset_) { + int64_t buffer_id = 0; + int32_t workers_done = 0; + int64_t rows_read = 0; + load_io_block_queue_ = true; + + while (workers_done < num_workers_) { + std::unique_ptr buffer; + RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); + if (buffer->eoe()) { + workers_done++; + } else if (num_samples_ == 0 || rows_read < num_samples_) { + if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { + int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); + RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); + } + rows_read += buffer->NumRows(); + buffer->set_id(buffer_id++); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); + } else { + // end of epoch + load_jagged_connector_ = false; + load_io_block_queue_ = false; + } + } + + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + finished_reading_dataset_ = true; + NotifyToFillIOBlockQueue(); + } else { + jagged_buffer_connector_->DoReset(); + buffer_id = 0; + } + } + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + RETURN_IF_NOT_OK(PostEndOfData()); + return Status::OK(); +} + +Status CsvOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t 
end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + return Status::OK(); +} + +// A print method typically used for debugging +void CsvOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ + << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ + << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n"; + for (int i = 0; i < csv_files_list_.size(); ++i) { + out << " " << csv_files_list_[i]; + } + out << "\n\n"; + } +} + +// Pops an element from a queue in io_block_queues +Status CsvOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status CsvOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { + std::mt19937 rng(seed); + std::shuffle(i_keys->begin(), i_keys->end(), rng); +} + +Status CsvOp::WaitToFillIOBlockQueue() { + // must be called first if called by worker spanwed by taskgroup + TaskManager::FindMe()->Post(); + + std::vector i_keys; + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + i_keys.push_back(it.key()); + } + } + uint32_t seed = 0; + while (true) { + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); + io_block_queue_wait_post_.Clear(); + + if (finished_reading_dataset_) { + break; + } + + if (shuffle_files_) { + ShuffleKeys(&i_keys, num_devices_ == 1 ? 
GetSeed() : ++seed); + } + RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); + } + return Status::OK(); +} + +Status CsvOp::FillIOBlockQueue(const std::vector &i_keys) { + int32_t queue_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + while (!finish) { + std::vector> file_index; + if (!i_keys.empty()) { + for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair((*filename_index_)[*it], *it)); + } + } else { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair(it.value(), it.key())); + } + } + for (auto file_info : file_index) { + if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { + auto ioBlock = + std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_info.first]; + } + + if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { + finish = false; + } else { + finish = true; + } + } + + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +void CsvOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +bool CsvOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count) { + *start_offset = 0; + *end_offset = 0; + bool push = false; + int64_t start_index = device_id_ * num_rows_per_shard_; + if (device_id_ + 1 < 0) { + MS_LOG(ERROR) << "Device id is invalid"; + return false; + } + + int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; + if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { + *start_offset = start_index - pre_count; + push = true; + if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + if (pre_count >= start_index && pre_count < end_index) { + *start_offset = 0; + push = true; + if (pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + return push; +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. +Status CsvOp::PostEndOfEpoch(int32_t queue_index) { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); + } + + return Status::OK(); +} + +Status CsvOp::CalculateNumRowsPerShard() { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + int64_t count = CountTotalRows(it.value()); + filename_numrows_[it.value()] = count; + all_num_rows_ += count; + } + if (all_num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API CsvDataset. 
Please check file path or dataset API " + "validation first."); + } + + num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); + MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; + return Status::OK(); +} + +int64_t CsvOp::CountTotalRows(const std::string &file) { + CsvParser csv_parser(0, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_); + std::ifstream ifs; + ifs.open(file, std::ifstream::in); + if (column_name_list_.empty()) { + std::string tmp; + getline(ifs, tmp); + } + csv_parser.Reset(); + while (ifs.good()) { + char chr = ifs.get(); + if (csv_parser.countRows(chr) != 0) { + break; + } + } + + return csv_parser.total_rows_; +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status CsvOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +Status CsvOp::CountAllFileRows(const std::vector &files, bool csv_header, int64_t *count) { + std::shared_ptr op; + *count = 0; + if (csv_header) { + RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).Build(&op)); + } else { + RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).SetColumName({""}).Build(&op)); + } + for (auto file : files) { + *count += op->CountTotalRows(file); + } + return Status::OK(); +} + +std::vector CsvOp::split(const std::string &s, char delim) { + std::vector res; + std::stringstream ss(s); + std::string item; + + while (getline(ss, item, delim)) { + res.push_back(item); + } + return res; +} + +Status CsvOp::ComputeColMap() { + // Set the column name mapping (base class field) + if (column_name_id_map_.empty()) { + if (column_name_list_.empty()) { + std::string line; + std::ifstream handle(csv_files_list_[0]); + getline(handle, line); + std::vector col_names = split(line, field_delim_); + for (int32_t i = 0; i < col_names.size(); i++) { + column_name_id_map_[col_names[i]] = i; + } + } else { + for (int32_t i = 0; i < column_name_list_.size(); i++) { + column_name_id_map_[column_name_list_[i]] = i; + } + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a456549e75631d51218068e55c583f1cc378fcba --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h @@ -0,0 +1,451 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "minddata/dataset/util/auto_index.h"
+#include "minddata/dataset/engine/datasetops/parallel_op.h"
+#include "minddata/dataset/engine/datasetops/source/io_block.h"
+
+namespace mindspore {
+namespace dataset {
+
+const size_t CSV_BUFFER_SIZE = 4096;
+using StringIndex = AutoIndexObj<std::string>;
+class JaggedConnector;
+
+class CsvOp : public ParallelOp {
+ public:
+  enum RecordType : uint8_t { INT = 0, FLOAT, STRING };
+
+  struct BaseRecord {
+   public:
+    BaseRecord() = default;
+    explicit BaseRecord(RecordType t) : type(t) {}
+    virtual ~BaseRecord() {}
+    RecordType type;
+  };
+
+  template <typename T>
+  class Record : public BaseRecord {
+   public:
+    Record() = default;
+    Record(RecordType t, T v) : BaseRecord(t), value(v) {}
+    ~Record() {}
+    T value;
+  };
+
+  // CsvParser is a class that parses CSV files.
+  // It implements the syntactic analysis with a state machine built from two state diagrams, 'sd' and 'sdl'.
+  // The 'sd' diagram handles the full CSV syntax; it is complete but complicated.
+  // The 'sdl' diagram only counts record rows; it is concise and runs fast.
+  struct CsvParser {
+   public:
+    CsvParser() = delete;
+
+    CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer,
+              char field_delim, std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default)
+        : worker_id_(worker_id),
+          buffer_connector_(connector),
+          csv_rows_per_buffer_(rows_per_buffer),
+          csv_field_delim_(field_delim),
+          column_default_(column_default),
+          cur_state_(START_OF_FILE),
+          pos_(0),
+          cur_row_(0),
+          cur_col_(0),
+          total_rows_(0),
+          start_offset_(0),
+          end_offset_(std::numeric_limits<int64_t>::max()) {
+      cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
+      initCsvParser();
+    }
+
+    ~CsvParser() = default;
+
+    void Reset() {
+      cur_state_ = START_OF_FILE;
+      pos_ = 0;
+      cur_row_ = 0;
+      cur_col_ = 0;
+    }
+
+    void setStartOffset(int64_t start_offset) { start_offset_ = start_offset; }
+
+    void setEndOffset(int64_t end_offset) { end_offset_ = end_offset; }
+
+    int processMessage(char c) {
+      Message m = getMessage(c);
+      StateDiagram::iterator it = sd.find({cur_state_, m});
+      if (it == sd.end()) {
+        return -1;
+      }
+      cur_state_ = it->second.first;
+      return it->second.second(*this, c);
+    }
+
+    int countRows(char c);
+
+    Status initCsvParser();
+
+    enum State : uint8_t {
+      START_OF_FILE = 0,
+      UNQUOTE,
+      DELIM,
+      QUOTE,
+      SECOND_QUOTE,
+      END_OF_LINE,
+      END_OF_FILE,
+      EXCEPTION
+    };
+
+    enum Message : uint8_t {
+      MS_NORMAL = 0,
+      MS_DELIM,
+      MS_QUOTE,
+      MS_END_OF_LINE,
+      MS_END_OF_FILE,
+    };
+
+    typedef std::pair<State, Message> StateMessagePair;
+    typedef std::pair<State, std::function<int(CsvParser &, char)>> StateActionPair;
+    typedef std::map<StateMessagePair, StateActionPair> StateDiagram;
+
+    Message getMessage(char c) {
+      if (c == csv_field_delim_) {
+        return Message::MS_DELIM;
+      } else if (c == '"') {
+        return Message::MS_QUOTE;
+      } else if (c == '\r' || c == '\n') {
+        return Message::MS_END_OF_LINE;
+      } else if (c == std::char_traits<char>::eof()) {
+        return Message::MS_END_OF_FILE;
+      } else {
+        return Message::MS_NORMAL;
+      }
+    }
+
+    int null_func(char c) { return 0; }
+
+    int put_char(char c) {
+      if (pos_ >= str_buf_.size()) {
+        str_buf_.resize(str_buf_.size() * 2);
+      }
+      str_buf_[pos_] = c;
+      pos_++;
+      return 0;
+    }
+
+    int put_record(char c);
+
+    int put_row(char c);
+
+    int end_file(char c);
+
+    int add_row(char c) {
+      total_rows_++;
+      return 0;
+    }
+
+    int catch_exception(char c) {
+      MS_LOG(ERROR) << "Invalid syntax!";
+      return -1;
+    }
+
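+    // Illustrative walk-through of the 'sd' diagram (example only, not referenced by the code):
+    // with ',' as the field delimiter, a hypothetical input line 1,"a,b",3\n is driven through
+    //   START_OF_FILE -> UNQUOTE -> DELIM -> QUOTE ... QUOTE -> SECOND_QUOTE -> DELIM -> UNQUOTE -> END_OF_LINE
+    // put_record() stores a field whenever its trailing delimiter is consumed, and put_row() commits
+    // the row (including the last field) at the end of the line, yielding the fields "1", "a,b" and "3".
+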
+ int32_t worker_id_; + std::shared_ptr buffer_connector_; + int64_t csv_rows_per_buffer_; + const char csv_field_delim_; + std::vector> column_default_; + State cur_state_; + size_t pos_; + int cur_row_; + int cur_col_; + int64_t total_rows_; + int64_t start_offset_; + int64_t end_offset_; + StateDiagram sd; + StateDiagram sdl; + std::vector str_buf_; + std::unique_ptr tensor_table_; + std::unique_ptr cur_buffer_; + }; + + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. + Status ValidateInputs() const; + + // Create the final object. + // @param op - dataset op. + // @return - the error code return. + Status Build(std::shared_ptr *op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetCsvFilesList(const std::vector &files_list) { + builder_csv_files_list_ = files_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleFiles(bool shuffle_files) { + builder_shuffle_files_ = shuffle_files; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumSamples(int64_t num_samples) { + builder_num_samples_ = num_samples; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetFieldDelim(char field_delim) { + builder_field_delim_ = field_delim; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetColumDefault(std::vector> record_list) { + builder_column_default_list_ = record_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. 
+ Builder &SetColumName(std::vector col_name_list) { + builder_column_name_list_ = col_name_list; + return *this; + } + + private: + int32_t builder_device_id_; + int32_t builder_num_devices_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_num_samples_; + int32_t builder_worker_connector_size_; + std::vector builder_csv_files_list_; + bool builder_shuffle_files_; + char builder_field_delim_; + std::vector> builder_column_default_list_; + std::vector builder_column_name_list_; + }; + + // Constructor of CsvOp + CsvOp() = delete; + + CsvOp(const std::vector &csv_files_list, char field_delim, + const std::vector> &column_default, const std::vector &column_name, + int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); + + // Default destructor + ~CsvOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // Instantiates the internal queues and connectors + // @return Status - the error code returned + Status Init(); + + // Class functor operator () override. + // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - the error code returned. + Status operator()() override; + + // Overrides base class reset method. Cleans up any state info from it's previous execution + // reinitializes itself so that it can be executed again, as if it was just created. + // @return Status - the error code returned. + Status Reset() override; + + // Get total rows in files. + // @param files - all csv files. + // @param csv_header - a bool that indicates csv file include header line + // @param count - number of rows. + // @return Status - the error coed returned. + static Status CountAllFileRows(const std::vector &files, bool csv_header, int64_t *count); + + // File names getter + // @return Vector of the input file names + std::vector FileNames() { return csv_files_list_; } + + private: + // The entry point for when workers are launched. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status WorkerEntry(int32_t worker_id) override; + + // Parses a single row and puts the data into a tensor table. + // @param line - the content of the row. + // @param tensor_table - the tensor table to put the parsed data in. + // @param row - the id of the row filled in the tensor table. + // @return Status - the error code returned. + Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); + + // Reads a csv file and loads the data into multiple buffers. + // @param file - the file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id); + + // Pops an element from a queue in IOBlockQueue. + // @param index - the index of the queue to pop from. 
+ // @param out_block - the popped element. + // @return Status - the error code returned. + Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); + + // Pushes an element to a queue in IOBlockQueue. + // @param index - the index of the queue to push to. + // @param io_block - the element to push onto the queue. + // @return Status - the error code returned. + Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); + + // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. + // @return Status - the error code returned. + Status WaitToFillIOBlockQueue(); + + // Fill the IOBlockQueue. + // @para i_keys - keys of file to fill to the IOBlockQueue + // @return Status - the error code returned. + Status FillIOBlockQueue(const std::vector &i_keys); + + // Notifies the thread which called FillIoBlockQueue to resume execution + void NotifyToFillIOBlockQueue(); + + // Select file and push it to the block queue. + // @param file_name - File name. + // @param start_file - If file contains the first sample of data. + // @param end_file - If file contains the end sample of data. + // @param pre_count - Total rows of previous files. + // @return Status - the error code returned. + bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker + // pops this control indicator, it will wait until the next epoch starts and then resume execution. + // @return Status - the error code returned. + Status PostEndOfEpoch(int32_t queue_index); + + // Calculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard(); + + // Count number of rows in each file. + // @param filename - csv file name. + // @return int64_t - the total number of rows in file. + int64_t CountTotalRows(const std::string &file); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. + // When the worker pops this control indicator, it will shut itself down gracefully. + // @return Status - the error code returned. + Status PostEndOfData(); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + // Split string based on a character delimiter + // @return - the a string vector + std::vector split(const std::string &s, char delim); + + int32_t device_id_; + bool shuffle_files_; + bool finished_reading_dataset_; + int32_t num_devices_; + int64_t rows_per_buffer_; + bool load_io_block_queue_; + int64_t num_rows_per_shard_; + int64_t all_num_rows_; + int64_t num_samples_; + std::map filename_numrows_; + std::unique_ptr filename_index_; + std::vector csv_files_list_; + WaitPost io_block_queue_wait_post_; + std::shared_ptr jagged_buffer_connector_; + QueueList> io_block_queues_; + bool load_jagged_connector_; + char field_delim_; + std::vector> column_default_list_; + std::vector column_name_list_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_ diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index b2d26b41eeec0337efff79c013233a73e36eeb79..eb9444a05a4d871b64c7bcc6eadd71528f71c7d2 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -21,7 +21,7 @@ can also create samplers with this module to sample data. 
from .core import config from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \ GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\ - TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset + TextFileDataset, CLUEDataset, CSVDataset, Schema, Shuffle, zip, RandomDataset from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ WeightedRandomSampler, Sampler from .engine.cache_client import DatasetCache @@ -31,5 +31,5 @@ from .engine.graphdata import GraphData __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset", - "CocoDataset", "TextFileDataset", "CLUEDataset", "Schema", "DistributedSampler", "PKSampler", + "CocoDataset", "TextFileDataset", "CLUEDataset", "CSVDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"] diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index b3624e1ca34ef2254ebd9ad29ee94b730ca80593..15eb66e54eb3a87dfe8bc95622d6beb7147157f3 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -29,7 +29,7 @@ from .samplers import * from ..core import config __all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset", - "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", + "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", "CSVDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 8aa6dff4f39e48802e4e9674cb639b11e9aa92f3..3cbb7ed2a343abd69f95abaca4447c133a878eb4 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -33,7 +33,7 @@ import copy import numpy as np from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ - MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo + MindRecordOp, TextFileOp, ClueOp, CsvOp, VOCOp, CocoOp, CBatchInfo from mindspore._c_expression import typing from mindspore import log as logger @@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ - check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save + check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE @@ -1012,7 +1012,7 @@ class Dataset: if isinstance(sampler, samplers.DistributedSampler): dev_id = sampler.shard_id return "", dev_id - if 
isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset)):
+        if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset, CSVDataset)):
             if output_dataset.shard_id is not None:
                 dev_id = output_dataset.shard_id
                 return "", dev_id
@@ -4652,8 +4652,8 @@ class CLUEDataset(SourceDataset):
         }
 
     Args:
-        dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
-            files. The list will be sorted in a lexicographical order.
+        dataset_files (str or a list of strings): String or list of files to be read or glob strings to search for
+            a pattern of files. The list will be sorted in a lexicographical order.
         task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'.
             (default=AFQMC).
         usage (str, optional): Need train, test or eval data (default="train").
@@ -4860,6 +4860,108 @@ class CLUEDataset(SourceDataset):
         return False
 
+
+class CSVDataset(SourceDataset):
+    """
+    A source dataset that reads and parses CSV datasets.
+
+    Args:
+        dataset_files (str or a list of strings): String or list of files to be read or glob strings to search
+            for a pattern of files. The list will be sorted in a lexicographical order.
+        field_delim (str, optional): A string that indicates the character delimiter to separate fields (default=',').
+        column_defaults (list, optional): List of default values for the CSV fields (default=None). Each item
+            in the list must be of type float, int or str. If this is not provided, all columns are treated
+            as strings.
+        column_names (list of string, optional): List of column names of the dataset (default=None). If this
+            is not provided, the column_names are inferred from the first row of the CSV file.
+        num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
+        num_parallel_workers (int, optional): Number of workers to read the data
+            (default=None, number set in the config).
+        shuffle (bool, Shuffle level, optional): Perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
+            If shuffle is False, no shuffling will be performed.
+            If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL.
+            Otherwise, there are two levels of shuffling:
+
+            - Shuffle.GLOBAL: Shuffle both the files and samples.
+
+            - Shuffle.FILES: Shuffle files only.
+
+        num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
+        shard_id (int, optional): The shard ID within num_shards (default=None). This
+            argument should be specified only when num_shards is also specified.
+ + Examples: + >>> import mindspore.dataset as ds + >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files + >>> dataset = ds.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4']) + """ + + @check_csvdataset + def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, + num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): + super().__init__(num_parallel_workers) + self.dataset_files = self._find_files(dataset_files) + self.dataset_files.sort() + self.field_delim = field_delim + self.column_defaults = column_defaults + self.column_names = column_names + self.num_samples = num_samples + + if not isinstance(shuffle, (bool, Shuffle)): + raise TypeError("shuffle should be of boolean or enum 'Shuffle'.") + if not isinstance(shuffle, Shuffle): + if shuffle: + self.shuffle_level = Shuffle.GLOBAL + self.shuffle_files = True + else: + self.shuffle_level = None + self.shuffle_files = False + else: + self.shuffle_level = shuffle + self.shuffle_files = True + + self.num_shards = num_shards + self.shard_id = shard_id + + def get_args(self): + args = super().get_args() + args["dataset_files"] = self.dataset_files + args['field_delim'] = self.field_delim + args['column_defaults'] = self.column_defaults + args['column_names'] = self.column_names + args["num_samples"] = self.num_samples + if self.shuffle_files is not None: + args["shuffle_files"] = self.shuffle_files + args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL) + args["shuffle"] = self.shuffle_level + args["num_shards"] = self.num_shards + args["shard_id"] = self.shard_id + return args + + def get_dataset_size(self): + """ + Get the number of batches in an epoch. + + Return: + Number, number of batches. + """ + if self._dataset_size is None: + num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None) + num_rows = get_num_rows(num_rows, self.num_shards) + if self.num_samples is None: + return num_rows + return min(self.num_samples, num_rows) + return self._dataset_size + + def is_shuffled(self): + return self.shuffle_files + + def is_sharded(self): + if self.num_shards is not None: + return self.num_shards > 1 + + return False + + class TextFileDataset(SourceDataset): """ A source dataset that reads and parses datasets stored on disk in text format. 
diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py
index 8c81dade67efa0c65b68137d53cab80d90da1fdf..e2d810b29a0e3b96172ae7302407f50838a1595f 100644
--- a/mindspore/dataset/engine/iterators.py
+++ b/mindspore/dataset/engine/iterators.py
@@ -185,6 +185,8 @@ class Iterator:
             op_type = OpName.SENTENCEPIECEVOCAB
         elif isinstance(dataset, de.CLUEDataset):
             op_type = OpName.CLUE
+        elif isinstance(dataset, de.CSVDataset):
+            op_type = OpName.CSV
         else:
             raise ValueError("Unsupported DatasetOp")
 
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py
index 23f24632a0913e5e8fbc2e7220da294181e2ec22..99826854caa94dbe923fcff8d55f5cc458366a43 100644
--- a/mindspore/dataset/engine/validators.py
+++ b/mindspore/dataset/engine/validators.py
@@ -787,6 +787,49 @@ def check_cluedataset(method):
     return new_method
 
 
+def check_csvdataset(method):
+    """A wrapper that wraps a parameter checker to the original Dataset(CSVDataset)."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        _, param_dict = parse_user_args(method, *args, **kwargs)
+
+        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
+
+        # check dataset_files; required argument
+        dataset_files = param_dict.get('dataset_files')
+        type_check(dataset_files, (str, list), "dataset files")
+
+        # check field_delim
+        field_delim = param_dict.get('field_delim')
+        type_check(field_delim, (str,), 'field delim')
+        if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
+            raise ValueError("field_delim is invalid; it should be a single character and cannot be '\"', '\\r' or '\\n'.")
+
+        # check column_defaults
+        column_defaults = param_dict.get('column_defaults')
+        if column_defaults is not None:
+            if not isinstance(column_defaults, list):
+                raise TypeError("column_defaults should be of type list.")
+            for item in column_defaults:
+                if not isinstance(item, (str, int, float)):
+                    raise TypeError("Each item in column_defaults should be of type int, float or str.")
+
+        # check column_names: must be list of string.
+        column_names = param_dict.get("column_names")
+        if column_names is not None:
+            all_string = all(isinstance(item, str) for item in column_names)
+            if not all_string:
+                raise TypeError("column_names should be a list of str.")
+
+        validate_dataset_param_value(nreq_param_int, param_dict, int)
+        check_sampler_shuffle_shard_options(param_dict)
+
+        return method(self, *args, **kwargs)
+
+    return new_method
+
+
 def check_textfiledataset(method):
     """A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
 
diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt
index 3d9e6ddf8915fa39de1f6bc01a3e3c5043a2fef5..050cc79db3cbe52163a8d164c2e7e4f57b18c05d 100644
--- a/tests/ut/cpp/dataset/CMakeLists.txt
+++ b/tests/ut/cpp/dataset/CMakeLists.txt
@@ -77,6 +77,7 @@ SET(DE_UT_SRCS
     celeba_op_test.cc
     take_op_test.cc
     clue_op_test.cc
+    csv_op_test.cc
     text_file_op_test.cc
     filter_op_test.cc
     concat_op_test.cc
diff --git a/tests/ut/cpp/dataset/csv_op_test.cc b/tests/ut/cpp/dataset/csv_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2eae7b98f11e8095bbcacf1832c4354c24ee2969
--- /dev/null
+++ b/tests/ut/cpp/dataset/csv_op_test.cc
@@ -0,0 +1,122 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/client.h"
+#include "common/common.h"
+#include "common/utils.h"
+#include "gtest/gtest.h"
+#include "utils/log_adapter.h"
+#include "minddata/dataset/engine/datasetops/source/csv_op.h"
+#include "minddata/dataset/util/status.h"
+
+
+namespace common = mindspore::common;
+
+using namespace mindspore::dataset;
+using mindspore::MsLogLevel::INFO;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::LogStream;
+
+class MindDataTestCSVOp : public UT::DatasetOpTesting {
+
+};
+
+TEST_F(MindDataTestCSVOp, TestCSVBasic) {
+  // Start with an empty execution tree
+  auto tree = std::make_shared<ExecutionTree>();
+
+  std::string dataset_path;
+  dataset_path = datasets_root_path_ + "/testCSV/1.csv";
+
+  std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
+  column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
+  column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
+  column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
+  column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
+  std::shared_ptr<CsvOp> op;
+  CsvOp::Builder builder;
+  builder.SetCsvFilesList({dataset_path})
+    .SetRowsPerBuffer(16)
+    .SetNumWorkers(16)
+    .SetShuffleFiles(false)
+    .SetOpConnectorSize(2)
+    .SetFieldDelim(',')
+    .SetColumDefault(column_default_list)
+    .SetColumName({"col1", "col2", "col3", "col4"});
+
+  Status rc = builder.Build(&op);
+  ASSERT_TRUE(rc.IsOk());
+
+  rc = tree->AssociateNode(op);
+  ASSERT_TRUE(rc.IsOk());
+
+  rc = tree->AssignRoot(op);
+  ASSERT_TRUE(rc.IsOk());
+
+  MS_LOG(INFO) << "Launching tree and begin iteration.";
+  rc = tree->Prepare();
+  ASSERT_TRUE(rc.IsOk());
+
+  rc = tree->Launch();
+  ASSERT_TRUE(rc.IsOk());
+
+  // Start the loop of reading tensors from our pipeline
+  DatasetIterator di(tree);
+  TensorRow tensor_list;
+  rc = di.FetchNextTensorRow(&tensor_list);
+  ASSERT_TRUE(rc.IsOk());
+
+  int row_count = 0;
+  while (!tensor_list.empty()) {
+    // Display the tensor by calling the printer on it
+    for (int i = 0; i < tensor_list.size(); i++) {
+      std::ostringstream ss;
+      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
+      MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
+    }
+
+    rc = di.FetchNextTensorRow(&tensor_list);
+    ASSERT_TRUE(rc.IsOk());
+    row_count++;
+  }
+
+  ASSERT_EQ(row_count, 3);
+}
+
+TEST_F(MindDataTestCSVOp, TestTotalRows) {
+  std::string csv_file1 = datasets_root_path_ + "/testCSV/1.csv";
+  std::string csv_file2 = datasets_root_path_ + "/testCSV/size.csv";
+  std::vector<std::string> files;
+  files.push_back(csv_file1);
+  int64_t total_rows = 0;
+  CsvOp::CountAllFileRows(files, false, &total_rows);
+  ASSERT_EQ(total_rows, 3);
+  files.clear();
+
+  files.push_back(csv_file2);
+  CsvOp::CountAllFileRows(files, false, &total_rows);
+  ASSERT_EQ(total_rows, 5);
+  files.clear();
+
+  files.push_back(csv_file1);
+  files.push_back(csv_file2);
+  CsvOp::CountAllFileRows(files, false, &total_rows);
+  ASSERT_EQ(total_rows, 8);
+  files.clear();
+}
diff --git a/tests/ut/data/dataset/testCSV/1.csv b/tests/ut/data/dataset/testCSV/1.csv
new file mode 100644
index 
0000000000000000000000000000000000000000..13fbfd70a1077824d6785f453324aecde419f954 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/1.csv @@ -0,0 +1,3 @@ +1,2,3,4 +5,6,7,8 +9,10,11,12 diff --git a/tests/ut/data/dataset/testCSV/2.csv b/tests/ut/data/dataset/testCSV/2.csv new file mode 100644 index 0000000000000000000000000000000000000000..b96a0a4ed386760791f26fdb995a50ecd3a3e622 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/2.csv @@ -0,0 +1,8 @@ +,"222",3,"4""" +"5",6,,"8" +9,10,"1""1",12 +,,"", +,,, + +a,b,c,"" +a,b,c,d diff --git a/tests/ut/data/dataset/testCSV/chinese.csv b/tests/ut/data/dataset/testCSV/chinese.csv new file mode 100644 index 0000000000000000000000000000000000000000..9445c527041abf708943172efd5ca624b31908d7 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/chinese.csv @@ -0,0 +1 @@ +大家,早上好,中午好,下午好,晚上好 diff --git a/tests/ut/data/dataset/testCSV/embedded.csv b/tests/ut/data/dataset/testCSV/embedded.csv new file mode 100644 index 0000000000000000000000000000000000000000..c7e10b013683f8509bf4106daa4616254ccb81f4 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/embedded.csv @@ -0,0 +1,2 @@ +"a,b","c""d","e +f"," g " diff --git a/tests/ut/data/dataset/testCSV/exception.csv b/tests/ut/data/dataset/testCSV/exception.csv new file mode 100644 index 0000000000000000000000000000000000000000..da5357efa57e6d2c61ca6f30121c0bcfa257248a --- /dev/null +++ b/tests/ut/data/dataset/testCSV/exception.csv @@ -0,0 +1,3 @@ +1,2,3,4 +5,6,7,8 +a,"c",d,"e diff --git a/tests/ut/data/dataset/testCSV/header.csv b/tests/ut/data/dataset/testCSV/header.csv new file mode 100644 index 0000000000000000000000000000000000000000..bf14e15263adfab4a8fdda9d4cb9438a3beaa17d --- /dev/null +++ b/tests/ut/data/dataset/testCSV/header.csv @@ -0,0 +1,2 @@ +col1,col2,col3,col4 +a,b,c,d \ No newline at end of file diff --git a/tests/ut/data/dataset/testCSV/number.csv b/tests/ut/data/dataset/testCSV/number.csv new file mode 100644 index 0000000000000000000000000000000000000000..2d3a7ec4c4797adadeed5d8436e95e870deb5926 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/number.csv @@ -0,0 +1 @@ +3,0.3,4,55.5 diff --git a/tests/ut/data/dataset/testCSV/quoted.csv b/tests/ut/data/dataset/testCSV/quoted.csv new file mode 100644 index 0000000000000000000000000000000000000000..5391bb9cc77f7e0c6e2677d216f50b69e268da98 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/quoted.csv @@ -0,0 +1 @@ +"a","b","c","d" diff --git a/tests/ut/data/dataset/testCSV/separated.csv b/tests/ut/data/dataset/testCSV/separated.csv new file mode 100644 index 0000000000000000000000000000000000000000..6a8e0ec28a16244e6beeb9a6271ee9bc4828a408 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/separated.csv @@ -0,0 +1 @@ +a|b|c|d diff --git a/tests/ut/data/dataset/testCSV/size.csv b/tests/ut/data/dataset/testCSV/size.csv new file mode 100644 index 0000000000000000000000000000000000000000..6ba3b2ba71cdc07604a3b057bb0f146bf9d667e8 --- /dev/null +++ b/tests/ut/data/dataset/testCSV/size.csv @@ -0,0 +1,10 @@ +1,2,3,4 +"a","b","c +","d +e" +5,6,7,8 +9,10,11,12 +a,"b +",c,"d +e" + diff --git a/tests/ut/python/dataset/test_datasets_csv.py b/tests/ut/python/dataset/test_datasets_csv.py new file mode 100644 index 0000000000000000000000000000000000000000..021bbe942fbd7cc0766bce89afefc111f84c1abb --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_csv.py @@ -0,0 +1,238 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import mindspore.dataset as ds +import numpy as np +import pytest + +DATA_FILE = '../data/dataset/testCSV/1.csv' + + +def test_csv_dataset_basic(): + """ + Test CSV with repeat, skip and so on + """ + TRAIN_FILE = '../data/dataset/testCSV/1.csv' + + buffer = [] + data = ds.CSVDataset( + TRAIN_FILE, + column_defaults=["0", 0, 0.0, "0"], + column_names=['1', '2', '3', '4'], + shuffle=False) + data = data.repeat(2) + data = data.skip(2) + for d in data.create_dict_iterator(): + buffer.append(d) + assert len(buffer) == 4 + + +def test_csv_dataset_one_file(): + data = ds.CSVDataset( + DATA_FILE, + column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.append(d) + assert len(buffer) == 3 + + +def test_csv_dataset_all_file(): + APPEND_FILE = '../data/dataset/testCSV/2.csv' + data = ds.CSVDataset( + [DATA_FILE, APPEND_FILE], + column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.append(d) + assert len(buffer) == 10 + + +def test_csv_dataset_num_samples(): + data = ds.CSVDataset( + DATA_FILE, + column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False, num_samples=2) + count = 0 + for _ in data.create_dict_iterator(): + count += 1 + assert count == 2 + + +def test_csv_dataset_distribution(): + TEST_FILE = '../data/dataset/testCSV/1.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False, num_shards=2, shard_id=0) + count = 0 + for _ in data.create_dict_iterator(): + count += 1 + assert count == 2 + + +def test_csv_dataset_quoted(): + TEST_FILE = '../data/dataset/testCSV/quoted.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", "", "", ""], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item().decode("utf8"), + d['col2'].item().decode("utf8"), + d['col3'].item().decode("utf8"), + d['col4'].item().decode("utf8")]) + assert buffer == ['a', 'b', 'c', 'd'] + + +def test_csv_dataset_separated(): + TEST_FILE = '../data/dataset/testCSV/separated.csv' + data = ds.CSVDataset( + TEST_FILE, + field_delim='|', + column_defaults=["", "", "", ""], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item().decode("utf8"), + d['col2'].item().decode("utf8"), + d['col3'].item().decode("utf8"), + d['col4'].item().decode("utf8")]) + assert buffer == ['a', 'b', 'c', 'd'] + + +def test_csv_dataset_embedded(): + TEST_FILE = '../data/dataset/testCSV/embedded.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", "", "", ""], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + 
buffer.extend([d['col1'].item().decode("utf8"), + d['col2'].item().decode("utf8"), + d['col3'].item().decode("utf8"), + d['col4'].item().decode("utf8")]) + assert buffer == ['a,b', 'c"d', 'e\nf', ' g '] + + +def test_csv_dataset_chinese(): + TEST_FILE = '../data/dataset/testCSV/chinese.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", "", "", "", ""], + column_names=['col1', 'col2', 'col3', 'col4', 'col5'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item().decode("utf8"), + d['col2'].item().decode("utf8"), + d['col3'].item().decode("utf8"), + d['col4'].item().decode("utf8"), + d['col5'].item().decode("utf8")]) + assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好'] + + +def test_csv_dataset_header(): + TEST_FILE = '../data/dataset/testCSV/header.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", "", "", ""], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item().decode("utf8"), + d['col2'].item().decode("utf8"), + d['col3'].item().decode("utf8"), + d['col4'].item().decode("utf8")]) + assert buffer == ['a', 'b', 'c', 'd'] + + +def test_csv_dataset_number(): + TEST_FILE = '../data/dataset/testCSV/number.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=[0.0, 0.0, 0, 0.0], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + buffer = [] + for d in data.create_dict_iterator(): + buffer.extend([d['col1'].item(), + d['col2'].item(), + d['col3'].item(), + d['col4'].item()]) + assert np.allclose(buffer, [3.0, 0.3, 4, 55.5]) + + +def test_csv_dataset_size(): + TEST_FILE = '../data/dataset/testCSV/size.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=[0.0, 0.0, 0, 0.0], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + assert data.get_dataset_size() == 5 + + +def test_csv_dataset_exception(): + TEST_FILE = '../data/dataset/testCSV/exception.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", "", "", ""], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + with pytest.raises(Exception) as err: + for _ in data.create_dict_iterator(): + pass + assert "Failed to parse CSV file" in str(err.value) + + +def test_csv_dataset_type_error(): + TEST_FILE = '../data/dataset/testCSV/exception.csv' + data = ds.CSVDataset( + TEST_FILE, + column_defaults=["", 0, "", ""], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + with pytest.raises(Exception) as err: + for _ in data.create_dict_iterator(): + pass + assert "invalid argument of stoi" in str(err.value) + + +if __name__ == "__main__": + test_csv_dataset_basic() + test_csv_dataset_one_file() + test_csv_dataset_all_file() + test_csv_dataset_num_samples() + test_csv_dataset_distribution() + test_csv_dataset_quoted() + test_csv_dataset_separated() + test_csv_dataset_embedded() + test_csv_dataset_chinese() + test_csv_dataset_header() + test_csv_dataset_number() + test_csv_dataset_size() + test_csv_dataset_exception() + test_csv_dataset_type_error()
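
For quick reference, the snippet below is a minimal usage sketch of the CSVDataset API added by this change. The file path, column names and default values are illustrative only (they assume a hypothetical four-column CSV file, not one of the test files shipped with this patch); the column types follow the types of the column_defaults entries, and num_shards/shard_id split the rows across readers as described in the class docstring.

    import mindspore.dataset as ds

    # Hypothetical input file. The type of each entry in column_defaults decides how the
    # corresponding column is parsed: col1/col2 as int, col3 as float, col4 as string.
    data = ds.CSVDataset(dataset_files="/path/to/train.csv",
                         field_delim=",",
                         column_defaults=[0, 0, 0.0, ""],
                         column_names=["col1", "col2", "col3", "col4"],
                         shuffle=ds.Shuffle.FILES,  # shuffle file order only; Shuffle.GLOBAL also shuffles rows
                         num_shards=2,              # split the rows into two shards ...
                         shard_id=0)                # ... and read shard 0

    for row in data.create_dict_iterator():
        print(row["col1"], row["col4"])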