Commit b091f74c authored by mindspore-ci-bot, committed by Gitee

!3016 Add CSV dataset loader

Merge pull request !3016 from jiangzhiwen/dataset/csv
...@@ -31,6 +31,7 @@
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
...@@ -88,6 +89,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp},
{kEpochCtrl, &DEPipeline::ParseEpochCtrlOp},
{kCsv, &DEPipeline::ParseCsvOp},
{kSentencePieceVocab, &DEPipeline::ParseBuildSentencePieceVocabOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
...@@ -1848,6 +1850,86 @@ Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num
return Status::OK();
}
Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
std::vector<std::string> files_list;
std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
if (!args["dataset_files"].is_none()) {
files_list = ToStringVector(args["dataset_files"]);
(void)builder->SetCsvFilesList(files_list);
} else {
RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
}
// Optional arguments
bool shuffle_required = false;
int64_t num_devices = 0;
std::vector<std::string> col_names;
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "shuffle_global") {
shuffle_required = ToBool(value);
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_shards") {
num_devices = ToInt(value);
(void)builder->SetNumDevices(num_devices);
} else if (key == "shard_id") {
(void)builder->SetDeviceId(ToInt(value));
} else if (key == "field_delim") {
(void)builder->SetFieldDelim(ToString(value)[0]);
} else if (key == "column_defaults") {
py::list py_object_list = py::reinterpret_borrow<py::list>(value);
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
for (auto l : py_object_list) {
std::string type_s = (std::string)py::str(l.get_type().attr("__name__"));
if (type_s == "int") {
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, ToInt(l)));
} else if (type_s == "float") {
column_default_list.push_back(std::make_shared<CsvOp::Record<float>>(CsvOp::FLOAT, ToFloat(l)));
} else if (type_s == "str") {
column_default_list.push_back(std::make_shared<CsvOp::Record<std::string>>(CsvOp::STRING, ToString(l)));
} else {
RETURN_STATUS_UNEXPECTED("Record type is not allowed");
}
}
(void)builder->SetColumDefault(column_default_list);
} else if (key == "column_names") {
col_names = ToStringVector(value);
(void)builder->SetColumName(col_names);
}
}
}
std::shared_ptr<CsvOp> csv_op;
RETURN_IF_NOT_OK(builder->Build(&csv_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op));
*top = csv_op;
if (shuffle_required) {
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t shuffle_size = 0;
int64_t num_rows = 0;
// First, get the number of rows in the dataset and then compute the shuffle size
RETURN_IF_NOT_OK(CsvOp::CountAllFileRows(files_list, col_names.empty(), &num_rows));
RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size));
// Add the shuffle op over top of this op and return the subtree (top/bottom) to caller
RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, csv_op, &shuffle_op));
*top = shuffle_op;
*bottom = csv_op;
}
return Status::OK();
}
// Helper function to inject a shuffle operator over top of the current operation being built.
Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *shuffle_op) {
......
...@@ -73,6 +73,7 @@ enum OpName {
kClue,
kEpochCtrl,
kSentencePieceVocab,
kCsv
};
// The C++ binder class that we expose to the python script.
...@@ -201,6 +202,8 @@ class DEPipeline {
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;
......
...@@ -19,6 +19,7 @@
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
...@@ -277,6 +278,17 @@ void bindDatasetOps(py::module *m) {
return count;
});
(void)py::class_<CsvOp, DatasetOp, std::shared_ptr<CsvOp>>(*m, "CsvOp")
.def_static("get_num_rows", [](const py::list &files, bool csv_header) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto file : files) {
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
}
THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count));
return count;
});
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_num_rows",
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
...@@ -1039,8 +1051,9 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile)
.value("EPOCHCTRL", OpName::kEpochCtrl)
.value("CSV", OpName::kCsv)
.value("CLUE", OpName::kClue);
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
.value("DE_JIEBA_MIX", JiebaMode::kMix)
......
...@@ -12,6 +12,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
celeba_op.cc
text_file_op.cc
clue_op.cc
csv_op.cc
)
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
...@@ -29,4 +30,4 @@ if (ENABLE_PYTHON)
)
endif()
add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES})
\ No newline at end of file
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <utility>
#include <limits>
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
namespace mindspore {
namespace dataset {
const size_t CSV_BUFFER_SIZE = 4096;
using StringIndex = AutoIndexObj<std::string>;
class JaggedConnector;
class CsvOp : public ParallelOp {
public:
enum RecordType : uint8_t { INT = 0, FLOAT, STRING };
struct BaseRecord {
public:
BaseRecord() = default;
explicit BaseRecord(RecordType t) : type(t) {}
virtual ~BaseRecord() {}
RecordType type;
};
template <typename T>
class Record : public BaseRecord {
public:
Record() = default;
Record(RecordType t, T v) : BaseRecord(t), value(v) {}
~Record() {}
T value;
};
// CsvParser is a class that parses CSV files.
// We design a state machine to implement the CSV syntactic analysis. It contains two state diagrams, 'sd' and 'sdl'.
// 'sd' is used for full CSV syntactic parsing; it is complete but complicated.
// 'sdl' is used only for counting record rows; it is concise and runs fast.
// (A simplified Python sketch of this table-driven design is given after this header.)
struct CsvParser {
public:
CsvParser() = delete;
CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer, char field_delim,
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default)
: worker_id_(worker_id),
buffer_connector_(connector),
csv_rows_per_buffer_(rows_per_buffer),
csv_field_delim_(field_delim),
column_default_(column_default),
cur_state_(START_OF_FILE),
pos_(0),
cur_row_(0),
cur_col_(0),
total_rows_(0),
start_offset_(0),
end_offset_(std::numeric_limits<int64_t>::max()) {
cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
initCsvParser();
}
~CsvParser() = default;
void Reset() {
cur_state_ = START_OF_FILE;
pos_ = 0;
cur_row_ = 0;
cur_col_ = 0;
}
void setStartOffset(int64_t start_offset) { start_offset_ = start_offset; }
void setEndOffset(int64_t end_offset) { end_offset_ = end_offset; }
int processMessage(char c) {
Message m = getMessage(c);
StateDiagram::iterator it = sd.find({cur_state_, m});
if (it == sd.end()) {
return -1;
}
cur_state_ = it->second.first;
return it->second.second(*this, c);
}
int countRows(char c);
Status initCsvParser();
enum State : uint8_t {
START_OF_FILE = 0,
UNQUOTE,
DELIM,
QUOTE,
SECOND_QUOTE,
END_OF_LINE,
END_OF_FILE,
EXCEPTION
};
enum Message : uint8_t {
MS_NORMAL = 0,
MS_DELIM,
MS_QUOTE,
MS_END_OF_LINE,
MS_END_OF_FILE,
};
typedef std::pair<State, Message> StateMessagePair;
typedef std::pair<State, std::function<int(CsvParser &, char)>> StateActionPair;
typedef std::map<StateMessagePair, StateActionPair> StateDiagram;
Message getMessage(char c) {
if (c == csv_field_delim_) {
return Message::MS_DELIM;
} else if (c == '"') {
return Message::MS_QUOTE;
} else if (c == '\r' || c == '\n') {
return Message::MS_END_OF_LINE;
} else if (c == std::char_traits<char>::eof()) {
return Message::MS_END_OF_FILE;
} else {
return Message::MS_NORMAL;
}
}
int null_func(char c) { return 0; }
int put_char(char c) {
if (pos_ >= str_buf_.size()) {
str_buf_.resize(str_buf_.size() * 2);
}
str_buf_[pos_] = c;
pos_++;
return 0;
}
int put_record(char c);
int put_row(char c);
int end_file(char c);
int add_row(char c) {
total_rows_++;
return 0;
}
int catch_exception(char c) {
MS_LOG(ERROR) << "Invalid syntax!";
return -1;
}
int32_t worker_id_;
std::shared_ptr<JaggedConnector> buffer_connector_;
int64_t csv_rows_per_buffer_;
const char csv_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_;
State cur_state_;
size_t pos_;
int cur_row_;
int cur_col_;
int64_t total_rows_;
int64_t start_offset_;
int64_t end_offset_;
StateDiagram sd;
StateDiagram sdl;
std::vector<char> str_buf_;
std::unique_ptr<TensorQTable> tensor_table_;
std::unique_ptr<DataBuffer> cur_buffer_;
};
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder();
// Default destructor
~Builder() = default;
// Checks if the inputs of the builder are valid.
// @return Status - the error code returned.
Status ValidateInputs() const;
// Create the final object.
// @param op - dataset op.
// @return Status - the error code returned.
Status Build(std::shared_ptr<CsvOp> *op);
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
builder_num_workers_ = num_workers;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
builder_rows_per_buffer_ = rows_per_buffer;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumDevices(int64_t num_dev) {
builder_num_devices_ = num_dev;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetDeviceId(int64_t dev_id) {
builder_device_id_ = dev_id;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetCsvFilesList(const std::vector<std::string> &files_list) {
builder_csv_files_list_ = files_list;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetShuffleFiles(bool shuffle_files) {
builder_shuffle_files_ = shuffle_files;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetFieldDelim(char field_delim) {
builder_field_delim_ = field_delim;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetColumDefault(std::vector<std::shared_ptr<CsvOp::BaseRecord>> record_list) {
builder_column_default_list_ = record_list;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetColumName(std::vector<std::string> col_name_list) {
builder_column_name_list_ = col_name_list;
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int64_t builder_rows_per_buffer_;
int64_t builder_num_samples_;
int32_t builder_worker_connector_size_;
std::vector<std::string> builder_csv_files_list_;
bool builder_shuffle_files_;
char builder_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
std::vector<std::string> builder_column_name_list_;
};
// Constructor of CsvOp
CsvOp() = delete;
CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~CsvOp() = default;
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
// Instantiates the internal queues and connectors
// @return Status - the error code returned
Status Init();
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Get total rows in files.
// @param files - all csv files.
// @param csv_header - a bool that indicates whether the csv files include a header line
// @param count - number of rows.
// @return Status - the error code returned.
static Status CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count);
// File names getter
// @return Vector of the input file names
std::vector<std::string> FileNames() { return csv_files_list_; }
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Parses a single row and puts the data into a tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.
// @param row - the id of the row filled in the tensor table.
// @return Status - the error code returned.
Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);
// Reads a csv file and loads the data into multiple buffers.
// @param file - the file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Fill the IOBlockQueue.
// @param i_keys - keys of the files to fill into the IOBlockQueue
// @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_offset - Output start offset within the file, if the file contains the first sample of data.
// @param end_offset - Output end offset within the file, if the file contains the last sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status CalculateNumRowsPerShard();
// Count number of rows in each file.
// @param file - csv file name.
// @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Private function for computing the assignment of the column name map.
// @return - Status
Status ComputeColMap() override;
// Split a string based on a character delimiter.
// @return - the resulting string vector
std::vector<std::string> split(const std::string &s, char delim);
int32_t device_id_;
bool shuffle_files_;
bool finished_reading_dataset_;
int32_t num_devices_;
int64_t rows_per_buffer_;
bool load_io_block_queue_;
int64_t num_rows_per_shard_;
int64_t all_num_rows_;
int64_t num_samples_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> csv_files_list_;
WaitPost io_block_queue_wait_post_;
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
char field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_;
std::vector<std::string> column_name_list_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
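The CsvParser above is table driven: each input character is classified into a Message, and the pair (current state, message) is looked up in the 'sd' diagram to obtain the next state and an action such as put_char, put_record or put_row. The following is a minimal, hypothetical Python sketch of that table-driven idea; quoting, offsets and buffering are omitted, so it illustrates the design rather than reproducing the CsvOp implementation.

# Hypothetical sketch: states and messages mirror a subset of CsvParser's enums.
UNQUOTE, END_OF_LINE = "UNQUOTE", "END_OF_LINE"
MS_NORMAL, MS_DELIM, MS_END_OF_LINE = "MS_NORMAL", "MS_DELIM", "MS_END_OF_LINE"

def classify(ch, delim):
    # Simplified equivalent of CsvParser::getMessage (no quote or EOF handling).
    if ch == delim:
        return MS_DELIM
    if ch in "\r\n":
        return MS_END_OF_LINE
    return MS_NORMAL

def parse_csv(text, delim=","):
    rows, row, field = [], [], []

    def put_char(ch):
        field.append(ch)

    def put_record(_):
        row.append("".join(field))
        field.clear()

    def put_row(ch):
        put_record(ch)
        rows.append(row.copy())
        row.clear()

    # (state, message) -> (next state, action), analogous to the 'sd' diagram.
    sd = {
        (UNQUOTE, MS_NORMAL): (UNQUOTE, put_char),
        (UNQUOTE, MS_DELIM): (UNQUOTE, put_record),
        (UNQUOTE, MS_END_OF_LINE): (END_OF_LINE, put_row),
        (END_OF_LINE, MS_NORMAL): (UNQUOTE, put_char),
        (END_OF_LINE, MS_DELIM): (UNQUOTE, put_record),
        (END_OF_LINE, MS_END_OF_LINE): (END_OF_LINE, lambda _: None),
    }
    state = END_OF_LINE  # plays the role of START_OF_FILE here
    for ch in text:
        state, action = sd[(state, classify(ch, delim))]
        action(ch)
    if field or row:  # flush a final row that has no trailing newline
        put_row(None)
    return rows

print(parse_csv("a,b,c\n1,2,3\n"))  # [['a', 'b', 'c'], ['1', '2', '3']]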
...@@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
from .core import config
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
TextFileDataset, CLUEDataset, CSVDataset, Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
from .engine.cache_client import DatasetCache
...@@ -31,5 +31,5 @@ from .engine.graphdata import GraphData
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset",
"CocoDataset", "TextFileDataset", "CLUEDataset", "CSVDataset", "Schema", "DistributedSampler", "PKSampler",
"RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
...@@ -29,7 +29,7 @@ from .samplers import *
from ..core import config
__all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", "CSVDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
"VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler",
"PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]
...@@ -33,7 +33,7 @@ import copy
import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
MindRecordOp, TextFileOp, ClueOp, CsvOp, VOCOp, CocoOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger
...@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
...@@ -1012,7 +1012,7 @@ class Dataset:
if isinstance(sampler, samplers.DistributedSampler):
dev_id = sampler.shard_id
return "", dev_id
if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset, CSVDataset)):
if output_dataset.shard_id is not None:
dev_id = output_dataset.shard_id
return "", dev_id
...@@ -4652,8 +4652,8 @@ class CLUEDataset(SourceDataset):
}
Args:
dataset_files (str or a list of strings): String or list of files to be read or glob strings to search for
a pattern of files. The list will be sorted in a lexicographical order.
task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'.
(default=AFQMC).
usage (str, optional): Need train, test or eval data (default="train").
...@@ -4860,6 +4860,108 @@ class CLUEDataset(SourceDataset):
return False
class CSVDataset(SourceDataset):
"""
A source dataset that reads and parses CSV datasets.
Args:
dataset_files (str or a list of strings): String or list of files to be read or glob strings to search
for a pattern of files. The list will be sorted in a lexicographical order.
field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=',').
column_defaults (list, optional): List of default values for the CSV fields (default=None). Each item
in the list must be of a valid type (float, int, or string); the type of each item also determines
the type of the corresponding output column. If this is not provided, all columns are treated as
string type.
column_names (list of string, optional): List of column names of the dataset (default=None). If this
is not provided, the column names are inferred from the first row of the CSV file.
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
shuffle (bool or Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed.
If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL.
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
Examples:
>>> import mindspore.dataset as ds
>>> dataset_files = ["/path/to/1", "/path/to/2"]  # contains one or more CSV files
>>> dataset = ds.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4'])

A fuller usage sketch follows this class definition.
"""
@check_csvdataset
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_files = self._find_files(dataset_files)
self.dataset_files.sort()
self.field_delim = field_delim
self.column_defaults = column_defaults
self.column_names = column_names
self.num_samples = num_samples
if not isinstance(shuffle, (bool, Shuffle)):
raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
if not isinstance(shuffle, Shuffle):
if shuffle:
self.shuffle_level = Shuffle.GLOBAL
self.shuffle_files = True
else:
self.shuffle_level = None
self.shuffle_files = False
else:
self.shuffle_level = shuffle
self.shuffle_files = True
self.num_shards = num_shards
self.shard_id = shard_id
def get_args(self):
args = super().get_args()
args["dataset_files"] = self.dataset_files
args['field_delim'] = self.field_delim
args['column_defaults'] = self.column_defaults
args['column_names'] = self.column_names
args["num_samples"] = self.num_samples
if self.shuffle_files is not None:
args["shuffle_files"] = self.shuffle_files
args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
if self._dataset_size is None:
num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
num_rows = get_num_rows(num_rows, self.num_shards)
if self.num_samples is None:
return num_rows
return min(self.num_samples, num_rows)
return self._dataset_size
def is_shuffled(self):
return self.shuffle_files
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return False
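A hedged usage sketch for the CSVDataset class defined above (the file path is hypothetical): the types of the column_defaults entries select the per-column output types, column_names is only needed when the files carry no header row, and Shuffle.FILES shuffles file order without a global row shuffle.

import mindspore.dataset as ds

data = ds.CSVDataset(
    dataset_files=["/path/to/train.csv"],           # hypothetical path or glob pattern
    field_delim=',',
    column_defaults=["", 0, 0.0, ""],               # str, int, float, str columns
    column_names=['id', 'label', 'score', 'text'],  # needed when the file has no header row
    shuffle=ds.Shuffle.FILES,                       # shuffle file order only
    num_shards=2, shard_id=0)                       # read the first of two shards

print(data.get_dataset_size())                      # rows visible to this shard
for row in data.create_dict_iterator():
    print(row['label'], row['score'])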
class TextFileDataset(SourceDataset):
"""
A source dataset that reads and parses datasets stored on disk in text format.
......
...@@ -185,6 +185,8 @@ class Iterator:
op_type = OpName.SENTENCEPIECEVOCAB
elif isinstance(dataset, de.CLUEDataset):
op_type = OpName.CLUE
elif isinstance(dataset, de.CSVDataset):
op_type = OpName.CSV
else:
raise ValueError("Unsupported DatasetOp")
......
...@@ -787,6 +787,49 @@ def check_cluedataset(method):
return new_method
def check_csvdataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset (CSVDataset). Illustrative rejected inputs are sketched after this checker."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
# check dataset_files; required argument
dataset_files = param_dict.get('dataset_files')
type_check(dataset_files, (str, list), "dataset files")
# check field_delim
field_delim = param_dict.get('field_delim')
type_check(field_delim, (str,), 'field delim')
if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
raise ValueError("field_delim is not legal.")
# check column_defaults
column_defaults = param_dict.get('column_defaults')
if column_defaults is not None:
if not isinstance(column_defaults, list):
raise TypeError("column_defaults should be type of list.")
for item in column_defaults:
if not isinstance(item, (str, int, float)):
raise TypeError("column type is not legal.")
# check column_names: must be list of string.
column_names = param_dict.get("column_names")
if column_names is not None:
all_string = all(isinstance(item, str) for item in column_names)
if not all_string:
raise TypeError("column_names should be a list of str.")
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
return method(self, *args, **kwargs)
return new_method
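A short sketch, assuming the checker behaves as written above, of inputs it rejects before any file is opened (the path is hypothetical):

import pytest
import mindspore.dataset as ds

# field_delim must be a single character and may not be '"', '\r' or '\n'
with pytest.raises(ValueError):
    ds.CSVDataset("/path/to/train.csv", field_delim='||')

# every column_defaults entry must be a str, int or float
with pytest.raises(TypeError):
    ds.CSVDataset("/path/to/train.csv", column_defaults=[{"bad": 1}])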
def check_textfiledataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
......
...@@ -77,6 +77,7 @@ SET(DE_UT_SRCS
celeba_op_test.cc
take_op_test.cc
clue_op_test.cc
csv_op_test.cc
text_file_op_test.cc
filter_op_test.cc
concat_op_test.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <vector>
#include "minddata/dataset/core/client.h"
#include "common/common.h"
#include "common/utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/util/status.h"
namespace common = mindspore::common;
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestCSVOp : public UT::DatasetOpTesting {
};
TEST_F(MindDataTestCSVOp, TestCSVBasic) {
// Start with an empty execution tree
auto tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testCSV/1.csv";
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
std::shared_ptr<CsvOp> op;
CsvOp::Builder builder;
builder.SetCsvFilesList({dataset_path})
.SetRowsPerBuffer(16)
.SetNumWorkers(16)
.SetShuffleFiles(false)
.SetOpConnectorSize(2)
.SetFieldDelim(',')
.SetColumDefault(column_default_list)
.SetColumName({"col1", "col2", "col3", "col4"});
Status rc = builder.Build(&op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssociateNode(op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssignRoot(op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration.";
rc = tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
// Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 3);
}
TEST_F(MindDataTestCSVOp, TestTotalRows) {
std::string csv_file1 = datasets_root_path_ + "/testCSV/1.csv";
std::string csv_file2 = datasets_root_path_ + "/testCSV/size.csv";
std::vector<std::string> files;
files.push_back(csv_file1);
int64_t total_rows = 0;
CsvOp::CountAllFileRows(files, false, &total_rows);
ASSERT_EQ(total_rows, 3);
files.clear();
files.push_back(csv_file2);
CsvOp::CountAllFileRows(files, false, &total_rows);
ASSERT_EQ(total_rows, 5);
files.clear();
files.push_back(csv_file1);
files.push_back(csv_file2);
CsvOp::CountAllFileRows(files, false, &total_rows);
ASSERT_EQ(total_rows, 8);
files.clear();
}
,"222",3,"4"""
"5",6,,"8"
9,10,"1""1",12
,,"",
,,,
a,b,c,""
a,b,c,d
大家,早上好,中午好,下午好,晚上好
col1,col2,col3,col4
a,b,c,d
\ No newline at end of file
1,2,3,4
"a","b","c
","d
e"
5,6,7,8
9,10,11,12
a,"b
",c,"d
e"
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import numpy as np
import pytest
DATA_FILE = '../data/dataset/testCSV/1.csv'
def test_csv_dataset_basic():
"""
Test CSV with repeat, skip and so on
"""
TRAIN_FILE = '../data/dataset/testCSV/1.csv'
buffer = []
data = ds.CSVDataset(
TRAIN_FILE,
column_defaults=["0", 0, 0.0, "0"],
column_names=['1', '2', '3', '4'],
shuffle=False)
data = data.repeat(2)
data = data.skip(2)
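# 1.csv has 3 rows, so repeat(2) yields 6 rows and skip(2) leaves 4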
for d in data.create_dict_iterator():
buffer.append(d)
assert len(buffer) == 4
def test_csv_dataset_one_file():
data = ds.CSVDataset(
DATA_FILE,
column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.append(d)
assert len(buffer) == 3
def test_csv_dataset_all_file():
APPEND_FILE = '../data/dataset/testCSV/2.csv'
data = ds.CSVDataset(
[DATA_FILE, APPEND_FILE],
column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.append(d)
assert len(buffer) == 10
def test_csv_dataset_num_samples():
data = ds.CSVDataset(
DATA_FILE,
column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False, num_samples=2)
count = 0
for _ in data.create_dict_iterator():
count += 1
assert count == 2
def test_csv_dataset_distribution():
TEST_FILE = '../data/dataset/testCSV/1.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False, num_shards=2, shard_id=0)
count = 0
for _ in data.create_dict_iterator():
count += 1
assert count == 2
def test_csv_dataset_quoted():
TEST_FILE = '../data/dataset/testCSV/quoted.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", "", "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item().decode("utf8"),
d['col2'].item().decode("utf8"),
d['col3'].item().decode("utf8"),
d['col4'].item().decode("utf8")])
assert buffer == ['a', 'b', 'c', 'd']
def test_csv_dataset_separated():
TEST_FILE = '../data/dataset/testCSV/separated.csv'
data = ds.CSVDataset(
TEST_FILE,
field_delim='|',
column_defaults=["", "", "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item().decode("utf8"),
d['col2'].item().decode("utf8"),
d['col3'].item().decode("utf8"),
d['col4'].item().decode("utf8")])
assert buffer == ['a', 'b', 'c', 'd']
def test_csv_dataset_embedded():
TEST_FILE = '../data/dataset/testCSV/embedded.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", "", "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item().decode("utf8"),
d['col2'].item().decode("utf8"),
d['col3'].item().decode("utf8"),
d['col4'].item().decode("utf8")])
assert buffer == ['a,b', 'c"d', 'e\nf', ' g ']
def test_csv_dataset_chinese():
TEST_FILE = '../data/dataset/testCSV/chinese.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", "", "", "", ""],
column_names=['col1', 'col2', 'col3', 'col4', 'col5'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item().decode("utf8"),
d['col2'].item().decode("utf8"),
d['col3'].item().decode("utf8"),
d['col4'].item().decode("utf8"),
d['col5'].item().decode("utf8")])
assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好']
def test_csv_dataset_header():
TEST_FILE = '../data/dataset/testCSV/header.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", "", "", ""],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item().decode("utf8"),
d['col2'].item().decode("utf8"),
d['col3'].item().decode("utf8"),
d['col4'].item().decode("utf8")])
assert buffer == ['a', 'b', 'c', 'd']
def test_csv_dataset_number():
TEST_FILE = '../data/dataset/testCSV/number.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=[0.0, 0.0, 0, 0.0],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
buffer = []
for d in data.create_dict_iterator():
buffer.extend([d['col1'].item(),
d['col2'].item(),
d['col3'].item(),
d['col4'].item()])
assert np.allclose(buffer, [3.0, 0.3, 4, 55.5])
def test_csv_dataset_size():
TEST_FILE = '../data/dataset/testCSV/size.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=[0.0, 0.0, 0, 0.0],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
assert data.get_dataset_size() == 5
def test_csv_dataset_exception():
TEST_FILE = '../data/dataset/testCSV/exception.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", "", "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
with pytest.raises(Exception) as err:
for _ in data.create_dict_iterator():
pass
assert "Failed to parse CSV file" in str(err.value)
def test_csv_dataset_type_error():
TEST_FILE = '../data/dataset/testCSV/exception.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", 0, "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
with pytest.raises(Exception) as err:
for _ in data.create_dict_iterator():
pass
assert "invalid argument of stoi" in str(err.value)
if __name__ == "__main__":
test_csv_dataset_basic()
test_csv_dataset_one_file()
test_csv_dataset_all_file()
test_csv_dataset_num_samples()
test_csv_dataset_distribution()
test_csv_dataset_quoted()
test_csv_dataset_separated()
test_csv_dataset_embedded()
test_csv_dataset_chinese()
test_csv_dataset_header()
test_csv_dataset_number()
test_csv_dataset_size()
test_csv_dataset_exception()
test_csv_dataset_type_error()