提交 2795e492 编写于 作者: Y yanghaitao

TextFileDataset

上级 18580a78
......@@ -28,10 +28,10 @@
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
#include "utils/log_adapter.h"
......@@ -61,7 +61,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kVoc, &DEPipeline::ParseVOCOp},
{kCifar10, &DEPipeline::ParseCifar10Op},
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp}};
{kCelebA, &DEPipeline::ParseCelebAOp},
{kTextFile, &DEPipeline::ParseTextFileOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
try {
......@@ -985,5 +986,37 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
// Required arguments
std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>();
if (!args["dataset_files"].is_none()) {
(void)builder->SetTextFilesList(ToStringVector(args["dataset_files"]));
} else {
RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
}
// Optional arguments
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_shards") {
(void)builder->SetNumDevices(ToInt(value));
} else if (key == "shard_id") {
(void)builder->SetDeviceId(ToInt(value));
}
}
}
std::shared_ptr<TextFileOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
......@@ -58,7 +58,8 @@ enum OpName {
kVoc,
kCifar10,
kCifar100,
kCelebA
kCelebA,
kTextFile
};
// The C++ binder class that we expose to the python script.
......@@ -148,6 +149,8 @@ class DEPipeline {
Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;
......
......@@ -55,6 +55,7 @@
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_operator.h"
......@@ -176,6 +177,17 @@ void bindDatasetOps(py::module *m) {
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
return count;
});
(void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp")
.def_static("get_num_rows", [](const py::list &files) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto file : files) {
!file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back("");
}
THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
return count;
});
}
void bindTensor(py::module *m) {
(void)py::class_<GlobalContext>(*m, "GlobalContext")
......@@ -463,7 +475,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("VOC", OpName::kVoc)
.value("CIFAR10", OpName::kCifar10)
.value("CIFAR100", OpName::kCifar100)
.value("CELEBA", OpName::kCelebA);
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile);
(void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
.value("DE_INTER_LINEAR", InterpolationMode::kLinear)
......
......@@ -18,6 +18,7 @@ add_library(engine-datasetops-source OBJECT
manifest_op.cc
cifar_op.cc
celeba_op.cc
text_file_op.cc
)
add_dependencies(engine-datasetops-source mindspore::protobuf)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include "common/utils.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/core/config_manager.h"
#include "dataset/util/task_manager.h"
#include "dataset/util/wait_post.h"
#include "dataset/util/random.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
TextFileOp::Builder::Builder()
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
builder_rows_per_buffer_ = config_manager->rows_per_buffer();
builder_worker_connector_size_ = config_manager->worker_connector_size();
}
Status TextFileOp::Builder::ValidateInputs() const {
std::string err_msg;
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greate than 0\n" : "";
err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
RETURN_IF_NOT_OK(ValidateInputs());
// Throttle the number of workers if we have more workers than files!
if (static_cast<size_t>(builder_num_workers_) > builder_text_files_list_.size()) {
builder_num_workers_ = builder_text_files_list_.size();
MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
}
builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(
builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
builder_num_devices_, builder_device_id_);
RETURN_IF_NOT_OK(text_file_op->Init());
*op = std::move(text_file_op);
return Status::OK();
}
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
rows_per_buffer_(rows_per_buffer),
num_samples_(num_samples),
text_files_list_(std::move(text_files_list)),
shuffle_files_(shuffle_files),
data_schema_(std::move(schema)),
all_num_rows_(0),
num_rows_per_shard_(0),
filename_index_(std::make_unique<StringIndex>()),
finished_reading_dataset_(false),
load_io_block_queue_(true),
load_jagged_connector_(true) {
worker_connector_size_ = worker_connector_size;
}
Status TextFileOp::Init() {
RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_));
int32_t safe_queue_size = static_cast<int32_t>(std::ceil(text_files_list_.size() / num_workers_) + 1);
io_block_queues_.Init(num_workers_, safe_queue_size);
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
col_name_map_[data_schema_->column(i).name()] = i;
}
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
return Status::OK();
}
Status TextFileOp::Reset() {
load_jagged_connector_ = true;
load_io_block_queue_ = true;
RETURN_IF_NOT_OK(ParallelOp::Reset());
NotifyToFillIOBlockQueue();
return Status::OK();
}
Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
TensorRow tRow(1, nullptr);
(*tensor_table)->push_back(std::move(tRow));
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(
Tensor::CreateTensor(&tensor, data_schema_->column(0).tensorImpl(),
TensorShape(std::vector<dsize_t>(1, line.size())), data_schema_->column(0).type(),
const_cast<unsigned char *>(reinterpret_cast<const unsigned char *>(common::SafeCStr(line)))));
(**tensor_table)[row][0] = std::move(tensor);
return Status::OK();
}
Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id) {
std::ifstream handle(file);
if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Failed to open file " + file);
}
int64_t rows_each_buffer = 0;
int64_t rows_total = 0;
std::string line;
std::unique_ptr<DataBuffer> cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
cur_buffer->set_column_name_map(col_name_map_);
std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
while (getline(handle, line)) {
// If read to the end offset of this file, break.
if (rows_total >= end_offset) {
break;
}
// Skip line before start offset.
if (rows_total < start_offset) {
rows_total++;
continue;
}
RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer));
rows_each_buffer++;
rows_total++;
if (rows_each_buffer == rows_per_buffer_) {
cur_buffer->set_tensor_table(std::move(tensor_table));
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
cur_buffer->set_column_name_map(col_name_map_);
tensor_table = std::make_unique<TensorQTable>();
rows_each_buffer = 0;
}
}
if (rows_each_buffer > 0) {
cur_buffer->set_tensor_table(std::move(tensor_table));
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
}
return Status::OK();
}
Status TextFileOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::unique_ptr<FilenameBlock> io_block;
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
while (!io_block->eof()) {
if (!io_block->eoe()) {
if (load_jagged_connector_) {
std::string filename;
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t start_offset = io_block->GetStartOffset();
int64_t end_offset = io_block->GetEndOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
}
} else {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
}
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
}
return Status::OK();
}
// Pops an element from a queue in io_block_queues
Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
return Status::OK();
}
// Pushes an element to a queue in io_block_queues
Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
return Status::OK();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status TextFileOp::PostEndOfData() {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
}
return Status::OK();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status TextFileOp::PostEndOfEpoch(int32_t queue_index) {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
}
return Status::OK();
}
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
std::mt19937 rng(seed);
std::shuffle(i_keys->begin(), i_keys->end(), rng);
}
bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count) {
*start_offset = 0;
*end_offset = 0;
bool push = false;
int64_t start_index = device_id_ * num_rows_per_shard_;
if (device_id_ + 1 < 0) {
MS_LOG(ERROR) << "Device id is invalid";
return false;
}
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
*start_offset = start_index - pre_count;
push = true;
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
if (pre_count >= start_index && pre_count < end_index) {
*start_offset = 0;
push = true;
if (pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
return push;
}
Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0;
int64_t pre_count = 0;
int64_t start_offset = 0;
int64_t end_offset = 0;
bool finish = false;
while (!finish) {
std::vector<std::pair<std::string, int64_t>> file_index;
if (!i_keys.empty()) {
for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
{
if (!load_io_block_queue_) {
break;
}
}
auto file_it = filename_index_->Search(*it);
file_index.emplace_back(std::pair<std::string, int64_t>(file_it.value(), *it));
}
} else {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
{
if (!load_io_block_queue_) {
break;
}
}
file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
}
}
for (auto file_info : file_index) {
if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
auto ioBlock =
std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
queue_index = (queue_index + 1) % num_workers_;
}
pre_count += filename_numrows_[file_info.first];
}
if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
finish = false;
} else {
finish = true;
}
}
RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
return Status::OK();
}
Status TextFileOp::WaitToFillIOBlockQueue() {
// must be called first if called by worker spanwed by taskgroup
TaskManager::FindMe()->Post();
std::vector<int64_t> i_keys;
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();
if (finished_reading_dataset_) {
break;
}
if (shuffle_files_) {
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
}
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
}
return Status::OK();
}
void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
Status TextFileOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
// launch one thread, responsible for filling IoBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this)));
// Read data from disk into buffers
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1)));
// must be called after launching workers.
TaskManager::FindMe()->Post();
io_block_queue_wait_post_.Register(tree_->AllTasks());
NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
load_io_block_queue_ = true;
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
if (buffer->eoe()) {
workers_done++;
} else if (num_samples_ == 0 || rows_read < num_samples_) {
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
}
rows_read += buffer->NumRows();
buffer->set_id(buffer_id++);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
} else {
// end of epoch
load_jagged_connector_ = false;
load_io_block_queue_ = false;
}
}
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
}
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
RETURN_IF_NOT_OK(PostEndOfData());
return Status::OK();
}
int64_t TextFileOp::CountTotalRows(const std::string &file) {
std::ifstream handle(file);
if (!handle.is_open()) {
MS_LOG(ERROR) << "Failed to open file: " << file;
return 0;
}
std::string line;
int64_t count = 0;
while (getline(handle, line)) {
count++;
}
return count;
}
Status TextFileOp::CalculateNumRowsPerShard() {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
int64_t count = CountTotalRows(it.value());
filename_numrows_[it.value()] = count;
all_num_rows_ += count;
}
if (all_num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED("Number of rows can not be zero");
}
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
return Status::OK();
}
Status TextFileOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
std::shared_ptr<TextFileOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op));
for (auto file : files) {
*count += op->CountTotalRows(file);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_
#include <memory>
#include <map>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "dataset/util/status.h"
#include "dataset/util/auto_index.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/queue.h"
#include "dataset/util/wait_post.h"
#include "dataset/engine/jagged_connector.h"
namespace mindspore {
namespace dataset {
using StringIndex = AutoIndexObj<std::string>;
class TextFileOp : public ParallelOp {
public:
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder();
// Default destructor
~Builder() = default;
// Checks if the inputs of the builder is valid.
// @return Status - the error code returned.
Status ValidateInputs() const;
// Create the final object.
// @param op - dataset op.
// @return - the error code return.
Status Build(std::shared_ptr<TextFileOp> *op);
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
builder_num_workers_ = num_workers;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
builder_rows_per_buffer_ = rows_per_buffer;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumDevices(int64_t num_dev) {
builder_num_devices_ = num_dev;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetDeviceId(int64_t dev_id) {
builder_device_id_ = dev_id;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetTextFilesList(const std::vector<std::string> &files_list) {
builder_text_files_list_ = files_list;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetShuffleFiles(bool shuffle_files) {
builder_shuffle_files_ = shuffle_files;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int64_t builder_rows_per_buffer_;
int64_t builder_num_samples_;
int32_t builder_worker_connector_size_;
std::vector<std::string> builder_text_files_list_;
bool builder_shuffle_files_;
std::unique_ptr<DataSchema> builder_schema_;
};
// Constructor of TextFileOp
// @note The builder class should be used to call this constructor.
// @param num_workers - number of worker threads reading data from tf_file files.
// @param rows_per_buffer - number of rows that a full buffer will contain.
// @param total_num_rows - number of rows to read
// @param dataset_files_list - list of filepaths for the dataset files.
// @param data_schema - the data schema object.
// @param op_connector_size - size of each queue in the connector that the child operator pulls from.
// @param columns_to_load - the names of the columns to load data from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~TextFileOp() = default;
// Instantiates the internal queues and connectors
// @return Status - the error code returned
Status Init();
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from it's previous execution
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Get total rows in files.
// @param files - all text files.
// @param count - number of rows.
// @return Status - the error coed returned.
static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Parses a single row and puts the data into a tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.
// @param row - the id of the row filled in the tensor table.
// @return Status - the error code returned.
Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);
// Reads a text file and loads the data into multiple buffers.
// @param file - the file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id);
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status CalculateNumRowsPerShard();
// Count number of rows in each file.
// @param filename - text file name.
// @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file);
// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Fill the IOBlockQueue.
// @para i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_file - If file contains the first sample of data.
// @param end_file - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;
int64_t num_samples_;
std::vector<std::string> text_files_list_;
bool shuffle_files_;
std::unique_ptr<DataSchema> data_schema_;
int64_t all_num_rows_;
int64_t num_rows_per_shard_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
WaitPost io_block_queue_wait_post_;
bool finished_reading_dataset_;
bool load_io_block_queue_;
bool load_jagged_connector_;
std::unordered_map<std::string, int32_t> col_name_map_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_
......@@ -20,8 +20,8 @@ can also create samplers with this module to sample data.
from .core.configuration import config
from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \
Shuffle, zip
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
Schema, Shuffle, zip
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler
from .engine.serializer_deserializer import serialize, deserialize, show
......@@ -29,5 +29,5 @@ from .engine.serializer_deserializer import serialize, deserialize, show
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
"VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
"VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
"SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip"]
......@@ -33,5 +33,5 @@ __all__ = ["config", "ConfigurationManager", "zip", "StorageDataset",
"ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
"VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
"SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]
"VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler",
"RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]
......@@ -29,7 +29,7 @@ from importlib import import_module
import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
MindRecordOp, CBatchInfo
MindRecordOp, TextFileOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger
......@@ -38,7 +38,7 @@ from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
check_zip_dataset, check_add_column
check_zip_dataset, check_add_column, check_textfiledataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
......@@ -888,6 +888,29 @@ class SourceDataset(Dataset):
# No need for __init__ since it is the same as the super's init
@staticmethod
def _find_files(patterns):
"""
Utility function to search for files with the given glob patterns.
Args:
patterns (str or list[str]): string or list of patterns to be searched.
Returns:
List, files.
"""
def flat(lists):
return list(np.array(lists).flatten())
if not isinstance(patterns, list):
patterns = [patterns]
file_list = flat([glob.glob(file, recursive=True) for file in patterns])
if file_list: # not empty
return file_list
raise ValueError("The list of path names matching the patterns is empty.")
class DatasetOp(Dataset):
"""
......@@ -2126,30 +2149,6 @@ class TFRecordDataset(SourceDataset):
>>> # 3) get all rows from dataset_files with schema file "./schema.json":
>>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
"""
@staticmethod
def _find_files(patterns):
"""
Utility function to search for files with the given glob patterns.
Args:
patterns (str or list[str]): string or list of patterns to be searched.
Returns:
List, files.
"""
def flat(lists):
return list(np.array(lists).flatten())
if not isinstance(patterns, list):
patterns = [patterns]
file_list = flat([glob.glob(file, recursive=True) for file in patterns])
if file_list: # not empty
return file_list
raise ValueError("The list of path names matching the patterns is empty.")
@check_tfrecorddataset
def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False):
......@@ -2952,3 +2951,82 @@ class CelebADataset(SourceDataset):
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
return args
class TextFileDataset(SourceDataset):
"""
A source dataset that reads and parses datasets stored on disk in text format.
The generated dataset has one columns ['text'].
Args:
dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
files. The list will be sorted in a lexicographical order.
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
Examples:
>>> import mindspore.dataset as ds
>>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files
>>> dataset = ds.TextFileDataset(dataset_files=dataset_files)
"""
@check_textfiledataset
def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None,
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_files = self._find_files(dataset_files)
self.dataset_files.sort()
self.num_samples = num_samples
if not isinstance(shuffle, (bool, Shuffle)):
raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
if not isinstance(shuffle, Shuffle):
if shuffle:
self.shuffle_level = Shuffle.GLOBAL
self.shuffle_files = True
else:
self.shuffle_level = None
self.shuffle_files = False
else:
self.shuffle_level = shuffle
self.shuffle_files = True
self.num_shards = num_shards
self.shard_id = shard_id
def get_args(self):
args = super().get_args()
args["dataset_files"] = self.dataset_files
args["num_samples"] = self.num_samples
if self.shuffle_files is not None:
args["shuffle_files"] = self.shuffle_files
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
if self._dataset_size is None:
num_rows = TextFileOp.get_num_rows(self.dataset_files)
num_rows = get_num_rows(num_rows, self.num_shards)
if self.num_samples is None:
return num_rows
return min(self.num_samples, num_rows)
return self._dataset_size
......@@ -48,12 +48,16 @@ def alter_tree(node):
def _alter_node(node):
"""Performing some alteration to a dataset node. A common alteration is to insert a node."""
if isinstance(node, de.TFRecordDataset) and node.shuffle_level == de.Shuffle.GLOBAL:
if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL:
# Remove the connection between the parent's node to the current node because we are inserting a node.
if node.output:
node.output.pop()
# Perform a fast scan for average rows per file
avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files)
if isinstance(node, de.TFRecordDataset):
avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files)
else:
avg_rows_per_file = node.get_dataset_size() // len(node.dataset_files)
# Shuffle between 4 files with a minimum size of 10000 rows
new_shuffle = node.shuffle(max(avg_rows_per_file * 4, 10000))
return new_shuffle
......@@ -157,6 +161,8 @@ class Iterator:
op_type = OpName.CIFAR100
elif isinstance(dataset, de.CelebADataset):
op_type = OpName.CELEBA
elif isinstance(dataset, de.TextFileDataset):
op_type = OpName.TEXTFILE
else:
raise ValueError("Unsupported DatasetOp")
......
......@@ -849,3 +849,25 @@ def check_add_column(method):
return method(*args, **kwargs)
return new_method
def check_textfiledataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset)."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
# check dataset_files; required argument
dataset_files = param_dict.get('dataset_files')
if dataset_files is None:
raise ValueError("dataset_files is not provided.")
if not isinstance(dataset_files, (str, list)):
raise TypeError("dataset_files should be of type str or a list of strings.")
check_param_type(nreq_param_int, param_dict, int)
return method(*args, **kwargs)
return new_method
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module is to support nlp augmentations. It includes two parts:
c_transforms and py_transforms. C_transforms is a high performance
image augmentation module which is developed with c++ opencv. Py_transforms
provide more kinds of image augmentations which is developed with python PIL.
"""
from .utils import as_text
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Some basic function for nlp
"""
import numpy as np
def as_text(array, encoding='utf8'):
"""
Convert data of array to unicode.
Args:
array (numpy array): Data of array should be ASCII values of each character after converted.
encoding (string): Indicating the charset for decoding.
Returns:
A 'str' object.
"""
if not isinstance(array, np.ndarray):
raise ValueError('input should be a numpy array')
byte_array = bytearray(list(array))
return byte_array.decode(encoding)
......@@ -65,7 +65,7 @@ SET(DE_UT_SRCS
cifar_op_test.cc
celeba_op_test.cc
take_op_test.cc
)
text_file_op_test.cc)
add_executable(de_ut_tests ${DE_UT_SRCS})
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <vector>
#include "dataset/core/client.h"
#include "common/common.h"
#include "common/utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/util/status.h"
namespace common = mindspore::common;
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestTextFileOp : public UT::DatasetOpTesting {
};
TEST_F(MindDataTestTextFileOp, TestTextFileBasic) {
// Start with an empty execution tree
auto tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testTextFileDataset/1.txt";
std::shared_ptr<TextFileOp> op;
TextFileOp::Builder builder;
builder.SetTextFilesList({dataset_path})
.SetRowsPerBuffer(16)
.SetNumWorkers(16)
.SetOpConnectorSize(2);
Status rc = builder.Build(&op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssociateNode(op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssignRoot(op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration.";
rc = tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
// Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 3);
}
TEST_F(MindDataTestTextFileOp, TestTotalRows) {
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
std::vector<std::string> files;
files.push_back(tf_file1);
int64_t total_rows = 0;
TextFileOp::CountAllFileRows(files, &total_rows);
ASSERT_EQ(total_rows, 3);
files.clear();
files.push_back(tf_file2);
TextFileOp::CountAllFileRows(files, &total_rows);
ASSERT_EQ(total_rows, 2);
files.clear();
files.push_back(tf_file1);
files.push_back(tf_file2);
TextFileOp::CountAllFileRows(files, &total_rows);
ASSERT_EQ(total_rows, 5);
files.clear();
}
This is a text file.
Be happy every day.
Good luck to everyone.
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
import mindspore.dataset.transforms.nlp.utils as nlp
DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
def test_textline_dataset_one_file():
data = ds.TextFileDataset(DATA_FILE)
count = 0
for i in data.create_dict_iterator():
logger.info("{}".format(i["text"]))
count += 1
assert(count == 3)
def test_textline_dataset_all_file():
data = ds.TextFileDataset(DATA_ALL_FILE)
count = 0
for i in data.create_dict_iterator():
logger.info("{}".format(i["text"]))
count += 1
assert(count == 5)
def test_textline_dataset_totext():
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
count = 0
line = ["This is a text file.", "Another file.", "Be happy every day.", "End of file.", "Good luck to everyone."]
for i in data.create_dict_iterator():
str = nlp.as_text(i["text"])
assert(str == line[count])
count += 1
assert(count == 5)
def test_textline_dataset_num_samples():
data = ds.TextFileDataset(DATA_FILE, num_samples=2)
count = 0
for i in data.create_dict_iterator():
count += 1
assert(count == 2)
def test_textline_dataset_distribution():
data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1)
count = 0
for i in data.create_dict_iterator():
count += 1
assert(count == 3)
def test_textline_dataset_repeat():
data = ds.TextFileDataset(DATA_FILE, shuffle=False)
data = data.repeat(3)
count = 0
line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
"This is a text file.", "Be happy every day.", "Good luck to everyone.",
"This is a text file.", "Be happy every day.", "Good luck to everyone."]
for i in data.create_dict_iterator():
str = nlp.as_text(i["text"])
assert(str == line[count])
count += 1
assert(count == 9)
def test_textline_dataset_get_datasetsize():
data = ds.TextFileDataset(DATA_FILE)
size = data.get_dataset_size()
assert(size == 3)
if __name__ == "__main__":
test_textline_dataset_one_file()
test_textline_dataset_all_file()
test_textline_dataset_totext()
test_textline_dataset_num_samples()
test_textline_dataset_distribution()
test_textline_dataset_repeat()
test_textline_dataset_get_datasetsize()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册