提交 7f39b5cf 编写于 作者: C Cathy Wong

C++ API Support for TextFile Dataset and Unit Tests

上级 4f75adb1
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "minddata/dataset/engine/datasetops/source/coco_op.h" #include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h" #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h" #include "minddata/dataset/engine/datasetops/source/voc_op.h"
// Dataset operator headers (in alphabetical order) // Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/batch_op.h" #include "minddata/dataset/engine/datasetops/batch_op.h"
...@@ -95,6 +96,7 @@ Dataset::Dataset() { ...@@ -95,6 +96,7 @@ Dataset::Dataset() {
num_workers_ = cfg->num_parallel_workers(); num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer(); rows_per_buffer_ = cfg->rows_per_buffer();
connector_que_size_ = cfg->op_connector_size(); connector_que_size_ = cfg->op_connector_size();
worker_connector_size_ = cfg->worker_connector_size();
} }
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS // FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
...@@ -140,7 +142,7 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str ...@@ -140,7 +142,7 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode, std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions, std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
std::map<std::string, int32_t> class_indexing) { std::map<std::string, int32_t> class_indexing) {
// This arg is exist in ImageFolderOp, but not externalized (in Python API). The default value is false. // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
bool recursive = false; bool recursive = false;
// Create logical representation of ImageFolderDataset. // Create logical representation of ImageFolderDataset.
...@@ -163,6 +165,16 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset ...@@ -163,6 +165,16 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
const std::shared_ptr<Dataset> &datasets2) { const std::shared_ptr<Dataset> &datasets2) {
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2})); std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a TextFileDataset.
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr; return ds->ValidateParams() ? ds : nullptr;
} }
...@@ -340,6 +352,34 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() { ...@@ -340,6 +352,34 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() {
return std::make_shared<RandomSamplerObj>(replacement, num_samples); return std::make_shared<RandomSamplerObj>(replacement, num_samples);
} }
// Helper function to compute a default shuffle size
int64_t ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows) {
const int64_t average_files_multiplier = 4;
const int64_t shuffle_max = 10000;
int64_t avg_rows_per_file = 0;
int64_t shuffle_size = 0;
// Adjust the num rows per shard if sharding was given
if (num_devices > 0) {
if (num_rows % num_devices == 0) {
num_rows = num_rows / num_devices;
} else {
num_rows = (num_rows / num_devices) + 1;
}
}
// Cap based on total rows directive. Some ops do not have this and give value of 0.
if (total_rows > 0) {
num_rows = std::min(num_rows, total_rows);
}
// get the average per file
avg_rows_per_file = num_rows / num_files;
shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
return shuffle_size;
}
// Helper function to validate dataset params // Helper function to validate dataset params
bool ValidateCommonDatasetParams(std::string dataset_dir) { bool ValidateCommonDatasetParams(std::string dataset_dir) {
if (dataset_dir.empty()) { if (dataset_dir.empty()) {
...@@ -613,6 +653,87 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() { ...@@ -613,6 +653,87 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
return node_ops; return node_ops;
} }
// Constructor for TextFileDataset
TextFileDataset::TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id)
: dataset_files_(dataset_files),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id) {}
bool TextFileDataset::ValidateParams() {
if (dataset_files_.empty()) {
MS_LOG(ERROR) << "TextFileDataset: dataset_files is not specified.";
return false;
}
for (auto file : dataset_files_) {
std::ifstream handle(file);
if (!handle.is_open()) {
MS_LOG(ERROR) << "TextFileDataset: Failed to open file: " << file;
return false;
}
}
if (num_samples_ < 0) {
MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_;
return false;
}
if (num_shards_ <= 0) {
MS_LOG(ERROR) << "TextFileDataset: Invalid num_shards: " << num_shards_;
return false;
}
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
MS_LOG(ERROR) << "TextFileDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
return false;
}
return true;
}
// Function to build TextFileDataset
std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
// Create and initalize TextFileOp
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), dataset_files_,
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
RETURN_EMPTY_IF_ERROR(text_file_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t shuffle_size = 0;
int64_t num_rows = 0;
// First, get the number of rows in the dataset and then compute the shuffle size
RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
shuffle_size = ComputeShuffleSize(dataset_files_.size(), num_shards_, num_rows, 0);
MS_LOG(INFO) << "TextFileDataset::Build - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
// Add the shuffle op after this op
shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size_, true, rows_per_buffer_);
node_ops.push_back(shuffle_op);
}
// Add TextFileOp
node_ops.push_back(text_file_op);
return node_ops;
}
// Constructor for VOCDataset // Constructor for VOCDataset
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
const std::map<std::string, int32_t> &class_index, bool decode, const std::map<std::string, int32_t> &class_index, bool decode,
......
...@@ -35,6 +35,9 @@ enum class DatasetType { kUnknown, kArrow, kTf }; ...@@ -35,6 +35,9 @@ enum class DatasetType { kUnknown, kArrow, kTf };
// Possible flavours of Tensor implementations // Possible flavours of Tensor implementations
enum class TensorImpl { kNone, kFlexible, kCv, kNP }; enum class TensorImpl { kNone, kFlexible, kCv, kNP };
// Possible values for shuffle
enum class ShuffleMode { kFalse = 0, kFiles = 1, kGlobal = 2 };
// Possible values for Border types // Possible values for Border types
enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 }; enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 };
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <map> #include <map>
#include <utility> #include <utility>
#include <string> #include <string>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/include/tensor.h" #include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/include/iterator.h" #include "minddata/dataset/include/iterator.h"
#include "minddata/dataset/include/samplers.h" #include "minddata/dataset/include/samplers.h"
...@@ -47,6 +48,7 @@ class Cifar100Dataset; ...@@ -47,6 +48,7 @@ class Cifar100Dataset;
class CocoDataset; class CocoDataset;
class ImageFolderDataset; class ImageFolderDataset;
class MnistDataset; class MnistDataset;
class TextFileDataset;
class VOCDataset; class VOCDataset;
// Dataset Op classes (in alphabetical order) // Dataset Op classes (in alphabetical order)
class BatchDataset; class BatchDataset;
...@@ -83,7 +85,7 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std: ...@@ -83,7 +85,7 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr); std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
/// \brief Function to create a Cifar100 Dataset /// \brief Function to create a Cifar100 Dataset
/// \notes The generated dataset has two columns ['image', 'coarse_label', 'fine_label'] /// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset /// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset /// will be used to randomly iterate the entire dataset
...@@ -143,6 +145,25 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam ...@@ -143,6 +145,25 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1, std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
const std::shared_ptr<Dataset> &datasets2); const std::shared_ptr<Dataset> &datasets2);
/// \brief Function to create a TextFileDataset
/// \notes The generated dataset has one column ['text']
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
/// will be sorted in a lexicographical order.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal)
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current TextFileDataset
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0);
/// \brief Function to create a VOCDataset /// \brief Function to create a VOCDataset
/// \notes The generated dataset has multi-columns : /// \notes The generated dataset has multi-columns :
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32], /// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
...@@ -289,10 +310,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> { ...@@ -289,10 +310,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
int32_t num_workers_; int32_t num_workers_;
int32_t rows_per_buffer_; int32_t rows_per_buffer_;
int32_t connector_que_size_; int32_t connector_que_size_;
int32_t worker_connector_size_;
}; };
/* ####################################### Derived Dataset classes ################################# */ /* ####################################### Derived Dataset classes ################################# */
// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
// (In alphabetical order)
class CelebADataset : public Dataset { class CelebADataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
...@@ -318,6 +343,8 @@ class CelebADataset : public Dataset { ...@@ -318,6 +343,8 @@ class CelebADataset : public Dataset {
std::set<std::string> extensions_; std::set<std::string> extensions_;
std::shared_ptr<SamplerObj> sampler_; std::shared_ptr<SamplerObj> sampler_;
}; };
// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
// (In alphabetical order)
class Cifar10Dataset : public Dataset { class Cifar10Dataset : public Dataset {
public: public:
...@@ -435,6 +462,33 @@ class MnistDataset : public Dataset { ...@@ -435,6 +462,33 @@ class MnistDataset : public Dataset {
std::shared_ptr<SamplerObj> sampler_; std::shared_ptr<SamplerObj> sampler_;
}; };
/// \class TextFileDataset
/// \brief A Dataset derived class to represent TextFile dataset
class TextFileDataset : public Dataset {
public:
/// \brief Constructor
TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
int32_t shard_id);
/// \brief Destructor
~TextFileDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
std::vector<std::string> dataset_files_;
int32_t num_samples_;
int32_t num_shards_;
int32_t shard_id_;
ShuffleMode shuffle_;
};
class VOCDataset : public Dataset { class VOCDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
...@@ -467,6 +521,9 @@ class VOCDataset : public Dataset { ...@@ -467,6 +521,9 @@ class VOCDataset : public Dataset {
std::shared_ptr<SamplerObj> sampler_; std::shared_ptr<SamplerObj> sampler_;
}; };
// DERIVED DATASET CLASSES FOR DATASET OPS
// (In alphabetical order)
class BatchDataset : public Dataset { class BatchDataset : public Dataset {
public: public:
/// \brief Constructor /// \brief Constructor
......
...@@ -5012,7 +5012,7 @@ class CSVDataset(SourceDataset): ...@@ -5012,7 +5012,7 @@ class CSVDataset(SourceDataset):
class TextFileDataset(SourceDataset): class TextFileDataset(SourceDataset):
""" """
A source dataset that reads and parses datasets stored on disk in text format. A source dataset that reads and parses datasets stored on disk in text format.
The generated dataset has one columns ['text']. The generated dataset has one column ['text'].
Args: Args:
dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a
......
...@@ -97,6 +97,7 @@ SET(DE_UT_SRCS ...@@ -97,6 +97,7 @@ SET(DE_UT_SRCS
c_api_dataset_ops_test.cc c_api_dataset_ops_test.cc
c_api_dataset_cifar_test.cc c_api_dataset_cifar_test.cc
c_api_dataset_coco_test.cc c_api_dataset_coco_test.cc
c_api_dataset_filetext_test.cc
c_api_dataset_voc_test.cc c_api_dataset_voc_test.cc
c_api_datasets_test.cc c_api_datasets_test.cc
c_api_dataset_iterator_test.cc c_api_dataset_iterator_test.cc
......
此差异已折叠。
...@@ -89,6 +89,23 @@ TEST_F(MindDataTestTextFileOp, TestTextFileBasic) { ...@@ -89,6 +89,23 @@ TEST_F(MindDataTestTextFileOp, TestTextFileBasic) {
ASSERT_EQ(row_count, 3); ASSERT_EQ(row_count, 3);
} }
TEST_F(MindDataTestTextFileOp, TestTextFileFileNotExist) {
// Start with an empty execution tree
auto tree = std::make_shared<ExecutionTree>();
std::string dataset_path = datasets_root_path_ + "/does/not/exist/0.txt";
std::shared_ptr<TextFileOp> op;
TextFileOp::Builder builder;
builder.SetTextFilesList({dataset_path})
.SetRowsPerBuffer(16)
.SetNumWorkers(16)
.SetOpConnectorSize(2);
Status rc = builder.Build(&op);
ASSERT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestTextFileOp, TestTotalRows) { TEST_F(MindDataTestTextFileOp, TestTotalRows) {
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt"; std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt"; std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
...@@ -110,3 +127,14 @@ TEST_F(MindDataTestTextFileOp, TestTotalRows) { ...@@ -110,3 +127,14 @@ TEST_F(MindDataTestTextFileOp, TestTotalRows) {
ASSERT_EQ(total_rows, 5); ASSERT_EQ(total_rows, 5);
files.clear(); files.clear();
} }
TEST_F(MindDataTestTextFileOp, TestTotalRowsFileNotExist) {
std::string tf_file1 = datasets_root_path_ + "/does/not/exist/0.txt";
std::vector<std::string> files;
files.push_back(tf_file1);
int64_t total_rows = 0;
TextFileOp::CountAllFileRows(files, &total_rows);
ASSERT_EQ(total_rows, 0);
}
...@@ -12,9 +12,10 @@ ...@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger
from util import config_get_set_num_parallel_workers from util import config_get_set_num_parallel_workers, config_get_set_seed
DATA_FILE = "../data/dataset/testTextFileDataset/1.txt" DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
...@@ -39,10 +40,54 @@ def test_textline_dataset_all_file(): ...@@ -39,10 +40,54 @@ def test_textline_dataset_all_file():
assert count == 5 assert count == 5
def test_textline_dataset_totext(): def test_textline_dataset_num_samples_zero():
data = ds.TextFileDataset(DATA_FILE, num_samples=0)
count = 0
for i in data.create_dict_iterator():
logger.info("{}".format(i["text"]))
count += 1
assert count == 3
def test_textline_dataset_shuffle_false4():
original_num_parallel_workers = config_get_set_num_parallel_workers(4) original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(987)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
count = 0
line = ["This is a text file.", "Another file.",
"Be happy every day.", "End of file.", "Good luck to everyone."]
for i in data.create_dict_iterator():
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 5
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_textline_dataset_shuffle_false1():
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(987)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False) data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
count = 0 count = 0
line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
"Another file.", "End of file."]
for i in data.create_dict_iterator():
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 5
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_textline_dataset_shuffle_files4():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(135)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
count = 0
line = ["This is a text file.", "Another file.", line = ["This is a text file.", "Another file.",
"Be happy every day.", "End of file.", "Good luck to everyone."] "Be happy every day.", "End of file.", "Good luck to everyone."]
for i in data.create_dict_iterator(): for i in data.create_dict_iterator():
...@@ -50,8 +95,60 @@ def test_textline_dataset_totext(): ...@@ -50,8 +95,60 @@ def test_textline_dataset_totext():
assert strs == line[count] assert strs == line[count]
count += 1 count += 1
assert count == 5 assert count == 5
# Restore configuration num_parallel_workers # Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers) ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_textline_dataset_shuffle_files1():
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(135)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
count = 0
line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
"Another file.", "End of file."]
for i in data.create_dict_iterator():
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 5
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_textline_dataset_shuffle_global4():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(246)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
count = 0
line = ["Another file.", "Good luck to everyone.", "End of file.",
"This is a text file.", "Be happy every day."]
for i in data.create_dict_iterator():
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 5
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_textline_dataset_shuffle_global1():
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(246)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
count = 0
line = ["Another file.", "Good luck to everyone.", "This is a text file.",
"End of file.", "Be happy every day."]
for i in data.create_dict_iterator():
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 5
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_textline_dataset_num_samples(): def test_textline_dataset_num_samples():
...@@ -94,11 +191,33 @@ def test_textline_dataset_to_device(): ...@@ -94,11 +191,33 @@ def test_textline_dataset_to_device():
data = data.to_device() data = data.to_device()
data.send() data.send()
def test_textline_dataset_exceptions():
with pytest.raises(ValueError) as error_info:
_ = ds.TextFileDataset(DATA_FILE, num_samples=-1)
assert "Input num_samples is not within the required interval" in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.TextFileDataset("does/not/exist/no.txt")
assert "The following patterns did not match any files" in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.TextFileDataset("")
assert "The following patterns did not match any files" in str(error_info.value)
if __name__ == "__main__": if __name__ == "__main__":
test_textline_dataset_one_file() test_textline_dataset_one_file()
test_textline_dataset_all_file() test_textline_dataset_all_file()
test_textline_dataset_totext() test_textline_dataset_num_samples_zero()
test_textline_dataset_shuffle_false4()
test_textline_dataset_shuffle_false1()
test_textline_dataset_shuffle_files4()
test_textline_dataset_shuffle_files1()
test_textline_dataset_shuffle_global4()
test_textline_dataset_shuffle_global1()
test_textline_dataset_num_samples() test_textline_dataset_num_samples()
test_textline_dataset_distribution() test_textline_dataset_distribution()
test_textline_dataset_repeat() test_textline_dataset_repeat()
test_textline_dataset_get_datasetsize() test_textline_dataset_get_datasetsize()
test_textline_dataset_to_device()
test_textline_dataset_exceptions()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册