提交 ea947568 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5605 Introduce usage flag to MNIST and CIFAR dataset

Merge pull request !5605 from ZiruiWu/add_usage_to_cifar_mnist_coco
......@@ -15,7 +15,7 @@
*/
#include <fstream>
#include <unordered_set>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h"
......@@ -132,26 +132,28 @@ std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::s
}
// Function to create a CelebADataset.
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type,
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, bool decode,
const std::set<std::string> &extensions) {
auto ds = std::make_shared<CelebADataset>(dataset_dir, dataset_type, sampler, decode, extensions);
auto ds = std::make_shared<CelebADataset>(dataset_dir, usage, sampler, decode, extensions);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a Cifar10Dataset.
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, usage, sampler);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a Cifar100Dataset.
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler);
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, usage, sampler);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
......@@ -217,8 +219,9 @@ std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const
#endif
// Function to create a MnistDataset.
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler);
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<MnistDataset>(dataset_dir, usage, sampler);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
......@@ -244,10 +247,10 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase
#ifndef ENABLE_ANDROID
// Function to create a VOCDataset.
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode,
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_indexing, decode, sampler);
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, usage, class_indexing, decode, sampler);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
......@@ -727,6 +730,10 @@ bool ValidateDatasetSampler(const std::string &dataset_name, const std::shared_p
return true;
}
bool ValidateStringValue(const std::string &str, const std::unordered_set<std::string> &valid_strings) {
return valid_strings.find(str) != valid_strings.end();
}
// Helper function to validate dataset input/output column parameter
bool ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
const std::vector<std::string> &columns) {
......@@ -802,29 +809,14 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumDataset::Build() {
}
// Constructor for CelebADataset
CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
const std::set<std::string> &extensions)
: dataset_dir_(dataset_dir),
dataset_type_(dataset_type),
sampler_(sampler),
decode_(decode),
extensions_(extensions) {}
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {}
bool CelebADataset::ValidateParams() {
if (!ValidateDatasetDirParam("CelebADataset", dataset_dir_)) {
return false;
}
if (!ValidateDatasetSampler("CelebADataset", sampler_)) {
return false;
}
std::set<std::string> dataset_type_list = {"all", "train", "valid", "test"};
auto iter = dataset_type_list.find(dataset_type_);
if (iter == dataset_type_list.end()) {
MS_LOG(ERROR) << "dataset_type should be one of 'all', 'train', 'valid' or 'test'.";
return false;
}
return true;
return ValidateDatasetDirParam("CelebADataset", dataset_dir_) && ValidateDatasetSampler("CelebADataset", sampler_) &&
ValidateStringValue(usage_, {"all", "train", "valid", "test"});
}
// Function to build CelebADataset
......@@ -839,17 +831,20 @@ std::vector<std::shared_ptr<DatasetOp>> CelebADataset::Build() {
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, dataset_type_, extensions_, std::move(schema),
decode_, usage_, extensions_, std::move(schema),
std::move(sampler_->Build())));
return node_ops;
}
// Constructor for Cifar10Dataset
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), sampler_(sampler) {}
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string &usage,
std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
bool Cifar10Dataset::ValidateParams() {
return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_) && ValidateDatasetSampler("Cifar10Dataset", sampler_);
return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_) &&
ValidateDatasetSampler("Cifar10Dataset", sampler_) &&
ValidateStringValue(usage_, {"train", "test", "all", ""});
}
// Function to build CifarOp for Cifar10
......@@ -864,19 +859,21 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->Build())));
return node_ops;
}
// Constructor for Cifar100Dataset
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), sampler_(sampler) {}
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::string &usage,
std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
bool Cifar100Dataset::ValidateParams() {
return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_) &&
ValidateDatasetSampler("Cifar100Dataset", sampler_);
ValidateDatasetSampler("Cifar100Dataset", sampler_) &&
ValidateStringValue(usage_, {"train", "test", "all", ""});
}
// Function to build CifarOp for Cifar100
......@@ -893,7 +890,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, num_workers_, rows_per_buffer_,
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->Build())));
return node_ops;
......@@ -1360,11 +1357,12 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestDataset::Build() {
}
#endif
MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), sampler_(sampler) {}
MnistDataset::MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
bool MnistDataset::ValidateParams() {
return ValidateDatasetDirParam("MnistDataset", dataset_dir_) && ValidateDatasetSampler("MnistDataset", sampler_);
return ValidateStringValue(usage_, {"train", "test", "all", ""}) &&
ValidateDatasetDirParam("MnistDataset", dataset_dir_) && ValidateDatasetSampler("MnistDataset", sampler_);
}
std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
......@@ -1378,8 +1376,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
std::move(schema), std::move(sampler_->Build())));
node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_->Build())));
return node_ops;
}
......@@ -1570,12 +1568,12 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordDataset::Build() {
#ifndef ENABLE_ANDROID
// Constructor for VOCDataset
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode,
std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir),
task_(task),
mode_(mode),
usage_(usage),
class_index_(class_indexing),
decode_(decode),
sampler_(sampler) {}
......@@ -1594,15 +1592,15 @@ bool VOCDataset::ValidateParams() {
MS_LOG(ERROR) << "class_indexing is invalid in Segmentation task.";
return false;
}
Path imagesets_file = dir / "ImageSets" / "Segmentation" / mode_ + ".txt";
Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt";
if (!imagesets_file.Exists()) {
MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
MS_LOG(ERROR) << "Invalid mode: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
return false;
}
} else if (task_ == "Detection") {
Path imagesets_file = dir / "ImageSets" / "Main" / mode_ + ".txt";
Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt";
if (!imagesets_file.Exists()) {
MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
MS_LOG(ERROR) << "Invalid mode: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
return false;
}
} else {
......@@ -1641,7 +1639,7 @@ std::vector<std::shared_ptr<DatasetOp>> VOCDataset::Build() {
}
std::shared_ptr<VOCOp> voc_op;
voc_op = std::make_shared<VOCOp>(task_type_, mode_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
node_ops.push_back(voc_op);
return node_ops;
......
......@@ -41,9 +41,9 @@ namespace dataset {
PYBIND_REGISTER(CifarOp, 1, ([](const py::module *m) {
(void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
.def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
.def_static("get_num_rows", [](const std::string &dir, const std::string &usage, bool isCifar10) {
int64_t count = 0;
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, usage, isCifar10, &count));
return count;
});
}));
......@@ -131,9 +131,9 @@ PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
PYBIND_REGISTER(MnistOp, 1, ([](const py::module *m) {
(void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
.def_static("get_num_rows", [](const std::string &dir) {
.def_static("get_num_rows", [](const std::string &dir, const std::string &usage) {
int64_t count = 0;
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, usage, &count));
return count;
});
}));
......
......@@ -1354,25 +1354,14 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
if (args["dataset_dir"].is_none()) {
std::string err_msg = "Error: No dataset path specified";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (args["task"].is_none()) {
std::string err_msg = "Error: No task specified";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (args["mode"].is_none()) {
std::string err_msg = "Error: No mode specified";
RETURN_STATUS_UNEXPECTED(err_msg);
}
CHECK_FAIL_RETURN_UNEXPECTED(!args["dataset_dir"].is_none(), "Error: No dataset path specified.");
CHECK_FAIL_RETURN_UNEXPECTED(!args["task"].is_none(), "Error: No task specified.");
CHECK_FAIL_RETURN_UNEXPECTED(!args["usage"].is_none(), "Error: No usage specified.");
std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
(void)builder->SetDir(ToString(args["dataset_dir"]));
(void)builder->SetTask(ToString(args["task"]));
(void)builder->SetMode(ToString(args["mode"]));
(void)builder->SetUsage(ToString(args["usage"]));
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
......@@ -1461,6 +1450,8 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
}
}
}
......@@ -1495,6 +1486,8 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
}
}
}
......@@ -1608,6 +1601,8 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
}
}
}
......@@ -1645,8 +1640,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
(void)builder->SetDecode(ToBool(value));
} else if (key == "extensions") {
(void)builder->SetExtensions(ToStringSet(value));
} else if (key == "dataset_type") {
(void)builder->SetDatasetType(ToString(value));
} else if (key == "usage") {
(void)builder->SetUsage(ToString(value));
}
}
}
......
......@@ -36,7 +36,7 @@ CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr)
Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
MS_LOG(DEBUG) << "Celeba dataset directory is " << builder_dir_.c_str() << ".";
MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << ".";
MS_LOG(DEBUG) << "Celeba dataset type is " << builder_usage_.c_str() << ".";
RETURN_IF_NOT_OK(SanityCheck());
if (builder_sampler_ == nullptr) {
const int64_t num_samples = 0;
......@@ -51,8 +51,8 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
RETURN_IF_NOT_OK(
builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
*op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
builder_op_connector_size_, builder_decode_, builder_dataset_type_,
builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
builder_op_connector_size_, builder_decode_, builder_usage_, builder_extensions_,
std::move(builder_schema_), std::move(builder_sampler_));
if (*op == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null");
}
......@@ -69,7 +69,7 @@ Status CelebAOp::Builder::SanityCheck() {
}
CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
bool decode, const std::string &usage, const std::set<std::string> &exts,
std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
......@@ -78,7 +78,7 @@ CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::stri
extensions_(exts),
data_schema_(std::move(schema)),
num_rows_in_attr_file_(0),
dataset_type_(dataset_type) {
usage_(usage) {
attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size);
io_block_queues_.Init(num_workers_, queue_size);
}
......@@ -135,7 +135,7 @@ Status CelebAOp::ParseAttrFile() {
std::vector<std::string> image_infos;
image_infos.reserve(oc_queue_size_);
while (getline(attr_file, image_info)) {
if ((image_info.empty()) || (dataset_type_ != "all" && !CheckDatasetTypeValid())) {
if ((image_info.empty()) || (usage_ != "all" && !CheckDatasetTypeValid())) {
continue;
}
image_infos.push_back(image_info);
......@@ -179,11 +179,11 @@ bool CelebAOp::CheckDatasetTypeValid() {
return false;
}
// train:0, valid=1, test=2
if (dataset_type_ == "train" && (type == 0)) {
if (usage_ == "train" && (type == 0)) {
return true;
} else if (dataset_type_ == "valid" && (type == 1)) {
} else if (usage_ == "valid" && (type == 1)) {
return true;
} else if (dataset_type_ == "test" && (type == 2)) {
} else if (usage_ == "test" && (type == 2)) {
return true;
}
......
......@@ -109,10 +109,10 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
}
// Setter method
// @param const std::string dataset_type: type to be read
// @param const std::string usage: type to be read
// @return Builder setter method returns reference to the builder.
Builder &SetDatasetType(const std::string &dataset_type) {
builder_dataset_type_ = dataset_type;
Builder &SetUsage(const std::string &usage) {
builder_usage_ = usage;
return *this;
}
// Check validity of input args
......@@ -133,7 +133,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
std::set<std::string> builder_extensions_;
std::shared_ptr<Sampler> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
std::string builder_dataset_type_;
std::string builder_usage_;
};
// Constructor
......@@ -143,12 +143,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @param int32_t queueSize - connector queue size
// @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
std::shared_ptr<Sampler> sampler);
~CelebAOp() override = default;
// Main Loop of CelebaOp
// Main Loop of CelebAOp
// Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
......@@ -177,7 +177,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// Op name getter
// @return Name of the current Op
std::string Name() const { return "CelebAOp"; }
std::string Name() const override { return "CelebAOp"; }
private:
// Called first when function is called
......@@ -232,7 +232,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
WaitPost wp_;
std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
std::string dataset_type_;
std::string usage_;
std::ifstream partition_file_;
};
} // namespace dataset
......
......@@ -18,15 +18,16 @@
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>
#include "utils/ms_utils.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
......@@ -36,7 +37,7 @@ constexpr uint32_t kCifarImageChannel = 3;
constexpr uint32_t kCifarBlockImageNum = 5;
constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel;
CifarOp::Builder::Builder() : sampler_(nullptr) {
CifarOp::Builder::Builder() : sampler_(nullptr), usage_("") {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
......@@ -65,23 +66,27 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
}
*ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
*ptr = std::make_shared<CifarOp>(cifar_type_, usage_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
std::move(schema_), std::move(sampler_));
return Status::OK();
}
Status CifarOp::Builder::SanityCheck() {
const std::set<std::string> valid = {"test", "train", "all", ""};
Path dir(dir_);
std::string err_msg;
err_msg += dir.IsDirectory() == false ? "Cifar path is invalid or not set\n" : "";
err_msg += num_workers_ <= 0 ? "Num of parallel workers is negative or 0\n" : "";
err_msg += valid.find(usage_) == valid.end() ? "usage needs to be 'train','test' or 'all'\n" : "";
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
: ParallelOp(num_works, queue_size, std::move(sampler)),
cifar_type_(type),
usage_(usage),
rows_per_buffer_(rows_per_buf),
folder_path_(file_dir),
data_schema_(std::move(data_schema)),
......@@ -258,21 +263,32 @@ Status CifarOp::ReadCifarBlockDataAsync() {
}
Status CifarOp::ReadCifar10BlockData() {
// CIFAR 10 has 6 bin files. data_batch_1.bin ... data_batch_5.bin and 1 test_batch.bin file
// each of the file has exactly 10K images and labels and size is 30,730 KB
// each image has the dimension of 32 x 32 x 3 = 3072 plus 1 label (label has 10 classes) so each row has 3073 bytes
constexpr uint32_t num_cifar10_records = 10000;
uint32_t block_size = (kCifarImageSize + 1) * kCifarBlockImageNum; // about 2M
std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
for (auto &file : cifar_files_) {
std::ifstream in(file, std::ios::binary);
if (!in.is_open()) {
std::string err_msg = file + " can not be opened.";
RETURN_STATUS_UNEXPECTED(err_msg);
// check the validity of the file path
Path file_path(file);
CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
std::string file_name = file_path.Basename();
if (usage_ == "train") {
if (file_name.find("data_batch") == std::string::npos) continue;
} else if (usage_ == "test") {
if (file_name.find("test_batch") == std::string::npos) continue;
} else { // get all the files that contain the word batch, aka any cifar 100 files
if (file_name.find("batch") == std::string::npos) continue;
}
std::ifstream in(file, std::ios::binary);
CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
for (uint32_t index = 0; index < num_cifar10_records / kCifarBlockImageNum; ++index) {
(void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
if (in.fail()) {
RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file);
}
CHECK_FAIL_RETURN_UNEXPECTED(!in.fail(), "Fail to read cifar file" + file);
(void)cifar_raw_data_block_->EmplaceBack(image_data);
}
in.close();
......@@ -283,15 +299,21 @@ Status CifarOp::ReadCifar10BlockData() {
}
Status CifarOp::ReadCifar100BlockData() {
// CIFAR 100 has 2 bin files. train.bin (60K imgs) 153,700KB and test.bin (30,740KB) (10K imgs)
// each img has two labels. Each row then is 32 * 32 *5 + 2 = 3,074 Bytes
uint32_t num_cifar100_records = 0; // test:10000, train:50000
uint32_t block_size = (kCifarImageSize + 2) * kCifarBlockImageNum; // about 2M
std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
for (auto &file : cifar_files_) {
int pos = file.find_last_of('/');
if (pos == std::string::npos) {
RETURN_STATUS_UNEXPECTED("Invalid cifar100 file path");
}
std::string file_name(file.substr(pos + 1));
// check the validity of the file path
Path file_path(file);
CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
std::string file_name = file_path.Basename();
// if usage is train/test, get only these 2 files
if (usage_ == "train" && file_name.find("train") == std::string::npos) continue;
if (usage_ == "test" && file_name.find("test") == std::string::npos) continue;
if (file_name.find("test") != std::string::npos) {
num_cifar100_records = 10000;
} else if (file_name.find("train") != std::string::npos) {
......@@ -301,15 +323,11 @@ Status CifarOp::ReadCifar100BlockData() {
}
std::ifstream in(file, std::ios::binary);
if (!in.is_open()) {
RETURN_STATUS_UNEXPECTED(file + " can not be opened.");
}
CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
for (uint32_t index = 0; index < num_cifar100_records / kCifarBlockImageNum; index++) {
(void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
if (in.fail()) {
RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file);
}
CHECK_FAIL_RETURN_UNEXPECTED(!in.fail(), "Fail to read cifar file" + file);
(void)cifar_raw_data_block_->EmplaceBack(image_data);
}
in.close();
......@@ -319,26 +337,20 @@ Status CifarOp::ReadCifar100BlockData() {
}
Status CifarOp::GetCifarFiles() {
// Initialize queue to hold the file names
const std::string kExtension = ".bin";
Path dataset_directory(folder_path_);
auto dirIt = Path::DirIterator::OpenDirectory(&dataset_directory);
Path dir_path(folder_path_);
auto dirIt = Path::DirIterator::OpenDirectory(&dir_path);
if (dirIt) {
while (dirIt->hasNext()) {
Path file = dirIt->next();
std::string filename = file.toString();
if (filename.find(kExtension) != std::string::npos) {
cifar_files_.push_back(filename);
MS_LOG(INFO) << "Cifar operator found file at " << filename << ".";
if (file.Extension() == kExtension) {
cifar_files_.push_back(file.toString());
}
}
} else {
std::string err_msg = "Unable to open directory " + dataset_directory.toString();
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (cifar_files_.size() == 0) {
RETURN_STATUS_UNEXPECTED("No .bin files found under " + folder_path_);
RETURN_STATUS_UNEXPECTED("Unable to open directory " + dir_path.toString());
}
CHECK_FAIL_RETURN_UNEXPECTED(!cifar_files_.empty(), "No .bin files found under " + folder_path_);
std::sort(cifar_files_.begin(), cifar_files_.end());
return Status::OK();
}
......@@ -378,9 +390,8 @@ Status CifarOp::ParseCifarData() {
num_rows_ = cifar_image_label_pairs_.size();
if (num_rows_ == 0) {
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
std::string err_msg = "There is no valid data matching the dataset API " + api +
".Please check file path or dataset API validation first.";
RETURN_STATUS_UNEXPECTED(err_msg);
RETURN_STATUS_UNEXPECTED("There is no valid data matching the dataset API " + api +
".Please check file path or dataset API validation first.");
}
cifar_raw_data_block_->Reset();
return Status::OK();
......@@ -403,46 +414,51 @@ Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
return Status::OK();
}
Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) {
Status CifarOp::CountTotalRows(const std::string &dir, const std::string &usage, bool isCIFAR10, int64_t *count) {
// the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block()
std::shared_ptr<CifarOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op));
RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).SetUsage(usage).Build(&op));
RETURN_IF_NOT_OK(op->GetCifarFiles());
if (op->cifar_type_ == kCifar10) {
constexpr int64_t num_cifar10_records = 10000;
for (auto &file : op->cifar_files_) {
std::ifstream in(file, std::ios::binary);
if (!in.is_open()) {
std::string err_msg = file + " can not be opened.";
RETURN_STATUS_UNEXPECTED(err_msg);
Path file_path(file);
CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
std::string file_name = file_path.Basename();
if (op->usage_ == "train") {
if (file_name.find("data_batch") == std::string::npos) continue;
} else if (op->usage_ == "test") {
if (file_name.find("test_batch") == std::string::npos) continue;
} else { // get all the files that contain the word batch, aka any cifar 100 files
if (file_name.find("batch") == std::string::npos) continue;
}
std::ifstream in(file, std::ios::binary);
CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
*count = *count + num_cifar10_records;
}
return Status::OK();
} else {
int64_t num_cifar100_records = 0;
for (auto &file : op->cifar_files_) {
size_t pos = file.find_last_of('/');
if (pos == std::string::npos) {
std::string err_msg = "Invalid cifar100 file path";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::string file_name;
if (file.size() > 0)
file_name = file.substr(pos + 1);
else
RETURN_STATUS_UNEXPECTED("Invalid string length!");
Path file_path(file);
std::string file_name = file_path.Basename();
CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
if (op->usage_ == "train" && file_path.Basename().find("train") == std::string::npos) continue;
if (op->usage_ == "test" && file_path.Basename().find("test") == std::string::npos) continue;
if (file_name.find("test") != std::string::npos) {
num_cifar100_records = 10000;
num_cifar100_records += 10000;
} else if (file_name.find("train") != std::string::npos) {
num_cifar100_records = 50000;
num_cifar100_records += 50000;
}
std::ifstream in(file, std::ios::binary);
if (!in.is_open()) {
std::string err_msg = file + " can not be opened.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
}
*count = num_cifar100_records;
return Status::OK();
......
......@@ -83,15 +83,23 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// Setter method
// @param const std::string & dir
// @return
// @return Builder setter method returns reference to the builder.
Builder &SetCifarDir(const std::string &dir) {
dir_ = dir;
return *this;
}
// Setter method
// @param const std::string &usage
// @return Builder setter method returns reference to the builder.
Builder &SetUsage(const std::string &usage) {
usage_ = usage;
return *this;
}
// Setter method
// @param const std::string & dir
// @return
// @return Builder setter method returns reference to the builder.
Builder &SetCifarType(const bool cifar10) {
if (cifar10) {
cifar_type_ = kCifar10;
......@@ -112,6 +120,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
private:
std::string dir_;
std::string usage_;
int32_t num_workers_;
int32_t rows_per_buffer_;
int32_t op_connect_size_;
......@@ -122,13 +131,15 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// Constructor
// @param CifarType type - Cifar10 or Cifar100
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
// @param uint32_t numWorks - Num of workers reading images in parallel
// @param uint32_t - rowsPerBuffer Number of images (rows) in each buffer
// @param std::string - dir directory of cifar dataset
// @param uint32_t - queueSize - connector queue size
// @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler);
// Destructor.
~CifarOp() = default;
......@@ -153,7 +164,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @param isCIFAR10 true if CIFAR10 and false if CIFAR100
// @param count output arg that will hold the actual dataset size
// @return
static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
static Status CountTotalRows(const std::string &dir, const std::string &usage, bool isCIFAR10, int64_t *count);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
......@@ -224,7 +235,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
std::unique_ptr<DataSchema> data_schema_;
int64_t row_cnt_;
int64_t buf_cnt_;
const std::string usage_; // can only be either "train" or "test"
WaitPost wp_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_;
......
......@@ -17,6 +17,7 @@
#include <fstream>
#include <iomanip>
#include <set>
#include "utils/ms_utils.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
......@@ -32,7 +33,7 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
const int32_t kMnistImageRows = 28;
const int32_t kMnistImageCols = 28;
MnistOp::Builder::Builder() : builder_sampler_(nullptr) {
MnistOp::Builder::Builder() : builder_sampler_(nullptr), builder_usage_("") {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
......@@ -52,22 +53,25 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
*ptr = std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
*ptr = std::make_shared<MnistOp>(builder_usage_, builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_));
return Status::OK();
}
Status MnistOp::Builder::SanityCheck() {
const std::set<std::string> valid = {"test", "train", "all", ""};
Path dir(builder_dir_);
std::string err_msg;
err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : "";
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : "";
err_msg += valid.find(builder_usage_) == valid.end() ? "usage needs to be 'train','test' or 'all'\n" : "";
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, queue_size, std::move(sampler)),
usage_(usage),
buf_cnt_(0),
row_cnt_(0),
folder_path_(folder_path),
......@@ -226,9 +230,7 @@ Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) {
uint32_t res = 0;
reader->read(reinterpret_cast<char *>(&res), 4);
if (reader->fail()) {
RETURN_STATUS_UNEXPECTED("Failed to read 4 bytes from file");
}
CHECK_FAIL_RETURN_UNEXPECTED(!reader->fail(), "Failed to read 4 bytes from file");
*result = SwapEndian(res);
return Status::OK();
}
......@@ -239,15 +241,12 @@ uint32_t MnistOp::SwapEndian(uint32_t val) const {
}
Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) {
if (image_reader->is_open() == false) {
RETURN_STATUS_UNEXPECTED("Cannot open mnist image file: " + file_name);
}
CHECK_FAIL_RETURN_UNEXPECTED(image_reader->is_open(), "Cannot open mnist image file: " + file_name);
int64_t image_len = image_reader->seekg(0, std::ios::end).tellg();
(void)image_reader->seekg(0, std::ios::beg);
// The first 16 bytes of the image file are type, number, row and column
if (image_len < 16) {
RETURN_STATUS_UNEXPECTED("Mnist file is corrupted.");
}
CHECK_FAIL_RETURN_UNEXPECTED(image_len >= 16, "Mnist file is corrupted.");
uint32_t magic_number;
RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number));
CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber,
......@@ -260,35 +259,25 @@ Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_re
uint32_t cols;
RETURN_IF_NOT_OK(ReadFromReader(image_reader, &cols));
// The image size of the Mnist dataset is fixed at [28,28]
if ((rows != kMnistImageRows) || (cols != kMnistImageCols)) {
RETURN_STATUS_UNEXPECTED("Wrong shape of image.");
}
if ((image_len - 16) != num_items * rows * cols) {
RETURN_STATUS_UNEXPECTED("Wrong number of image.");
}
CHECK_FAIL_RETURN_UNEXPECTED((rows == kMnistImageRows) && (cols == kMnistImageCols), "Wrong shape of image.");
CHECK_FAIL_RETURN_UNEXPECTED((image_len - 16) == num_items * rows * cols, "Wrong number of image.");
*num_images = num_items;
return Status::OK();
}
Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) {
if (label_reader->is_open() == false) {
RETURN_STATUS_UNEXPECTED("Cannot open mnist label file: " + file_name);
}
CHECK_FAIL_RETURN_UNEXPECTED(label_reader->is_open(), "Cannot open mnist label file: " + file_name);
int64_t label_len = label_reader->seekg(0, std::ios::end).tellg();
(void)label_reader->seekg(0, std::ios::beg);
// The first 8 bytes of the image file are type and number
if (label_len < 8) {
RETURN_STATUS_UNEXPECTED("Mnist file is corrupted.");
}
CHECK_FAIL_RETURN_UNEXPECTED(label_len >= 8, "Mnist file is corrupted.");
uint32_t magic_number;
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number));
CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber,
"This is not the mnist label file: " + file_name);
uint32_t num_items;
RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items));
if ((label_len - 8) != num_items) {
RETURN_STATUS_UNEXPECTED("Wrong number of labels!");
}
CHECK_FAIL_RETURN_UNEXPECTED((label_len - 8) == num_items, "Wrong number of labels!");
*num_labels = num_items;
return Status::OK();
}
......@@ -330,6 +319,9 @@ Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *la
}
Status MnistOp::ParseMnistData() {
// MNIST contains 4 files, idx3 are image files, idx 1 are labels
// training files contain 60K examples and testing files contain 10K examples
// t10k-images-idx3-ubyte t10k-labels-idx1-ubyte train-images-idx3-ubyte train-labels-idx1-ubyte
for (size_t i = 0; i < image_names_.size(); ++i) {
std::ifstream image_reader, label_reader;
image_reader.open(image_names_[i], std::ios::binary);
......@@ -354,18 +346,22 @@ Status MnistOp::ParseMnistData() {
Status MnistOp::WalkAllFiles() {
const std::string kImageExtension = "idx3-ubyte";
const std::string kLabelExtension = "idx1-ubyte";
const std::string train_prefix = "train";
const std::string test_prefix = "t10k";
Path dir(folder_path_);
auto dir_it = Path::DirIterator::OpenDirectory(&dir);
std::string prefix; // empty string, used to match usage = "" (default) or usage == "all"
if (usage_ == "train" || usage_ == "test") prefix = (usage_ == "test" ? test_prefix : train_prefix);
if (dir_it != nullptr) {
while (dir_it->hasNext()) {
Path file = dir_it->next();
std::string filename = file.toString();
if (filename.find(kImageExtension) != std::string::npos) {
image_names_.push_back(filename);
std::string filename = file.Basename();
if (filename.find(prefix + "-images-" + kImageExtension) != std::string::npos) {
image_names_.push_back(file.toString());
MS_LOG(INFO) << "Mnist operator found image file at " << filename << ".";
} else if (filename.find(kLabelExtension) != std::string::npos) {
label_names_.push_back(filename);
} else if (filename.find(prefix + "-labels-" + kLabelExtension) != std::string::npos) {
label_names_.push_back(file.toString());
MS_LOG(INFO) << "Mnist Operator found label file at " << filename << ".";
}
}
......@@ -376,9 +372,7 @@ Status MnistOp::WalkAllFiles() {
std::sort(image_names_.begin(), image_names_.end());
std::sort(label_names_.begin(), label_names_.end());
if (image_names_.size() != label_names_.size()) {
RETURN_STATUS_UNEXPECTED("num of images does not equal to num of labels");
}
CHECK_FAIL_RETURN_UNEXPECTED(image_names_.size() == label_names_.size(), "num of idx3 files != num of idx1 files");
return Status::OK();
}
......@@ -397,11 +391,11 @@ Status MnistOp::LaunchThreadsAndInitOp() {
return Status::OK();
}
Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
Status MnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
// the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader()
std::shared_ptr<MnistOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op));
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetUsage(usage).Build(&op));
RETURN_IF_NOT_OK(op->WalkAllFiles());
......
......@@ -47,8 +47,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
class Builder {
public:
// Constructor for Builder class of MnistOp
// @param uint32_t numWrks - number of parallel workers
// @param dir - directory folder got ImageNetFolder
Builder();
// Destructor.
......@@ -87,13 +85,20 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
}
// Setter method
// @param const std::string & dir
// @param const std::string &dir
// @return
Builder &SetDir(const std::string &dir) {
builder_dir_ = dir;
return *this;
}
// Setter method
// @param const std::string &usage
// @return
Builder &SetUsage(const std::string &usage) {
builder_usage_ = usage;
return *this;
}
// Check validity of input args
// @return - The error code return
Status SanityCheck();
......@@ -105,6 +110,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
private:
std::string builder_dir_;
std::string builder_usage_;
int32_t builder_num_workers_;
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
......@@ -113,14 +119,15 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
};
// Constructor
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
// @param int32_t num_workers - number of workers reading images in parallel
// @param int32_t rows_per_buffer - number of images (rows) in each buffer
// @param std::string folder_path - dir directory of mnist
// @param int32_t queue_size - connector queue size
// @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
// @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read
MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path,
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
// Destructor.
~MnistOp() = default;
......@@ -150,7 +157,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @param dir path to the MNIST directory
// @param count output arg that will hold the minimum of the actual dataset size and numSamples
// @return
static Status CountTotalRows(const std::string &dir, int64_t *count);
static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
......@@ -241,6 +248,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
WaitPost wp_;
std::string folder_path_; // directory of image folder
int32_t rows_per_buffer_;
const std::string usage_; // can only be either "train" or "test"
std::unique_ptr<DataSchema> data_schema_;
std::vector<MnistLabelPair> image_label_pairs_;
std::vector<std::string> image_names_;
......
......@@ -18,14 +18,15 @@
#include <algorithm>
#include <fstream>
#include <iomanip>
#include "./tinyxml2.h"
#include "utils/ms_utils.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "utils/ms_utils.h"
using tinyxml2::XMLDocument;
using tinyxml2::XMLElement;
......@@ -81,7 +82,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
}
*ptr = std::make_shared<VOCOp>(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_,
*ptr = std::make_shared<VOCOp>(builder_task_type_, builder_usage_, builder_dir_, builder_labels_to_read_,
builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_,
builder_decode_, std::move(builder_schema_), std::move(builder_sampler_));
return Status::OK();
......@@ -103,7 +104,7 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std:
row_cnt_(0),
buf_cnt_(0),
task_type_(task_type),
task_mode_(task_mode),
usage_(task_mode),
folder_path_(folder_path),
class_index_(class_index),
rows_per_buffer_(rows_per_buffer),
......@@ -251,10 +252,9 @@ Status VOCOp::WorkerEntry(int32_t worker_id) {
Status VOCOp::ParseImageIds() {
std::string image_sets_file;
if (task_type_ == TaskType::Segmentation) {
image_sets_file =
folder_path_ + std::string(kImageSetsSegmentation) + task_mode_ + std::string(kImageSetsExtension);
image_sets_file = folder_path_ + std::string(kImageSetsSegmentation) + usage_ + std::string(kImageSetsExtension);
} else if (task_type_ == TaskType::Detection) {
image_sets_file = folder_path_ + std::string(kImageSetsMain) + task_mode_ + std::string(kImageSetsExtension);
image_sets_file = folder_path_ + std::string(kImageSetsMain) + usage_ + std::string(kImageSetsExtension);
}
std::ifstream in_file;
in_file.open(image_sets_file);
......@@ -431,13 +431,13 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(
Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
*count = static_cast<int64_t>(op->image_ids_.size());
} else if (task_type == "Segmentation") {
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op));
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
*count = static_cast<int64_t>(op->image_ids_.size());
}
......@@ -458,7 +458,7 @@ Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_t
} else {
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(
Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
for (const auto label : op->label_index_) {
......
......@@ -73,7 +73,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
}
// Setter method.
// @param const std::string & task_type
// @param const std::string &task_type
// @return Builder setter method returns reference to the builder.
Builder &SetTask(const std::string &task_type) {
if (task_type == "Segmentation") {
......@@ -85,10 +85,10 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
}
// Setter method.
// @param const std::string & task_mode
// @param const std::string &usage
// @return Builder setter method returns reference to the builder.
Builder &SetMode(const std::string &task_mode) {
builder_task_mode_ = task_mode;
Builder &SetUsage(const std::string &usage) {
builder_usage_ = usage;
return *this;
}
......@@ -145,7 +145,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
bool builder_decode_;
std::string builder_dir_;
TaskType builder_task_type_;
std::string builder_task_mode_;
std::string builder_usage_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int32_t builder_rows_per_buffer_;
......@@ -279,7 +279,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
int64_t buf_cnt_;
std::string folder_path_;
TaskType task_type_;
std::string task_mode_;
std::string usage_;
int32_t rows_per_buffer_;
std::unique_ptr<DataSchema> data_schema_;
......
......@@ -111,34 +111,36 @@ std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::s
/// \brief Function to create a CelebADataset
/// \notes The generated dataset has two columns ['image', 'attr'].
// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
/// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test'.
/// \param[in] usage One of "all", "train", "valid" or "test".
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \param[in] decode Decode the images after reading (default=false).
/// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
/// \return Shared pointer to the current Dataset
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all",
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage = "all",
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), bool decode = false,
const std::set<std::string> &extensions = {});
/// \brief Function to create a Cifar10 Dataset
/// \notes The generated dataset has two columns ['image', 'label']
/// \notes The generated dataset has two columns ["image", "label"]
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] usage of CIFAR10, can be "train", "test" or "all"
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage = std::string(),
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
/// \brief Function to create a Cifar100 Dataset
/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
/// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"]
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] usage of CIFAR100, can be "train", "test" or "all"
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage = std::string(),
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
/// \brief Function to create a CLUEDataset
......@@ -212,7 +214,7 @@ std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, c
/// \brief Function to create an ImageFolderDataset
/// \notes A source dataset that reads images from a tree of directories
/// All images within one folder have the same label
/// The generated dataset has two columns ['image', 'label']
/// The generated dataset has two columns ["image", "label"]
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] decode A flag to decode in ImageFolder
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
......@@ -227,7 +229,7 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir,
#ifndef ENABLE_ANDROID
/// \brief Function to create a ManifestDataset
/// \notes The generated dataset has two columns ['image', 'label']
/// \notes The generated dataset has two columns ["image", "label"]
/// \param[in] dataset_file The dataset file to be read
/// \param[in] usage Need "train", "eval" or "inference" data (default="train")
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
......@@ -243,12 +245,13 @@ std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const
#endif
/// \brief Function to create a MnistDataset
/// \notes The generated dataset has two columns ['image', 'label']
/// \notes The generated dataset has two columns ["image", "label"]
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] usage of MNIST, can be "train", "test" or "all"
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \return Shared pointer to the current MnistDataset
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir,
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage = std::string(),
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
/// \brief Function to create a ConcatDataset
......@@ -404,14 +407,14 @@ std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &datase
/// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection"
/// \param[in] mode Set the data list txt file to be readed
/// \param[in] usage The type of data list text file to be read
/// \param[in] class_indexing A str-to-int mapping from label name to index, only valid in "Detection" task
/// \param[in] decode Decode the images after reading
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \return Shared pointer to the current Dataset
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
const std::string &mode = "train",
const std::string &usage = "train",
const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
const std::shared_ptr<SamplerObj> &sampler = RandomSampler());
#endif
......@@ -702,9 +705,8 @@ class AlbumDataset : public Dataset {
class CelebADataset : public Dataset {
public:
/// \brief Constructor
CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
const std::set<std::string> &extensions);
CelebADataset(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
const bool &decode, const std::set<std::string> &extensions);
/// \brief Destructor
~CelebADataset() = default;
......@@ -719,7 +721,7 @@ class CelebADataset : public Dataset {
private:
std::string dataset_dir_;
std::string dataset_type_;
std::string usage_;
bool decode_;
std::set<std::string> extensions_;
std::shared_ptr<SamplerObj> sampler_;
......@@ -730,7 +732,7 @@ class CelebADataset : public Dataset {
class Cifar10Dataset : public Dataset {
public:
/// \brief Constructor
Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
Cifar10Dataset(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler);
/// \brief Destructor
~Cifar10Dataset() = default;
......@@ -745,13 +747,14 @@ class Cifar10Dataset : public Dataset {
private:
std::string dataset_dir_;
std::string usage_;
std::shared_ptr<SamplerObj> sampler_;
};
class Cifar100Dataset : public Dataset {
public:
/// \brief Constructor
Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
Cifar100Dataset(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler);
/// \brief Destructor
~Cifar100Dataset() = default;
......@@ -766,6 +769,7 @@ class Cifar100Dataset : public Dataset {
private:
std::string dataset_dir_;
std::string usage_;
std::shared_ptr<SamplerObj> sampler_;
};
......@@ -831,7 +835,7 @@ class CocoDataset : public Dataset {
enum CsvType : uint8_t { INT = 0, FLOAT, STRING };
/// \brief Base class of CSV Record
struct CsvBase {
class CsvBase {
public:
CsvBase() = default;
explicit CsvBase(CsvType t) : type(t) {}
......@@ -936,7 +940,7 @@ class ManifestDataset : public Dataset {
class MnistDataset : public Dataset {
public:
/// \brief Constructor
MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler);
MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler);
/// \brief Destructor
~MnistDataset() = default;
......@@ -951,6 +955,7 @@ class MnistDataset : public Dataset {
private:
std::string dataset_dir_;
std::string usage_;
std::shared_ptr<SamplerObj> sampler_;
};
......@@ -1087,7 +1092,7 @@ class TFRecordDataset : public Dataset {
class VOCDataset : public Dataset {
public:
/// \brief Constructor
VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage,
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler);
/// \brief Destructor
......@@ -1110,7 +1115,7 @@ class VOCDataset : public Dataset {
const std::string kColumnTruncate = "truncate";
std::string dataset_dir_;
std::string task_;
std::string mode_;
std::string usage_;
std::map<std::string, int32_t> class_index_;
bool decode_;
std::shared_ptr<SamplerObj> sampler_;
......
......@@ -132,6 +132,12 @@ def check_valid_detype(type_):
return True
def check_valid_str(value, valid_strings, arg_name=""):
type_check(value, (str,), arg_name)
if value not in valid_strings:
raise ValueError("Input {0} is not within the valid set of {1}.".format(arg_name, str(valid_strings)))
def check_columns(columns, name):
"""
Validate strings in column_names.
......
......@@ -2877,6 +2877,9 @@ class MnistDataset(MappableDataset):
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Usage of this dataset, can be "train", "test" or "all" . "train" will read from 60,000
train samples, "test" will read from 10,000 test samples, "all" will read from all 70,000 samples.
(default=None, all samples)
num_samples (int, optional): The number of images to be included in the dataset
(default=None, all images).
num_parallel_workers (int, optional): Number of workers to read the data
......@@ -2906,11 +2909,12 @@ class MnistDataset(MappableDataset):
"""
@check_mnist_cifar_dataset
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
shuffle=None, sampler=None, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.usage = usage
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
self.shuffle_level = shuffle
......@@ -2920,6 +2924,7 @@ class MnistDataset(MappableDataset):
def get_args(self):
args = super().get_args()
args["dataset_dir"] = self.dataset_dir
args["usage"] = self.usage
args["num_samples"] = self.num_samples
args["shuffle"] = self.shuffle_level
args["sampler"] = self.sampler
......@@ -2935,7 +2940,7 @@ class MnistDataset(MappableDataset):
Number, number of batches.
"""
if self.dataset_size is None:
num_rows = MnistOp.get_num_rows(self.dataset_dir)
num_rows = MnistOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
......@@ -3913,6 +3918,9 @@ class Cifar10Dataset(MappableDataset):
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Usage of this dataset, can be "train", "test" or "all" . "train" will read from 50,000
train samples, "test" will read from 10,000 test samples, "all" will read from all 60,000 samples.
(default=None, all samples)
num_samples (int, optional): The number of images to be included in the dataset.
(default=None, all images).
num_parallel_workers (int, optional): Number of workers to read the data
......@@ -3946,11 +3954,12 @@ class Cifar10Dataset(MappableDataset):
"""
@check_mnist_cifar_dataset
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
shuffle=None, sampler=None, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.usage = usage
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
self.num_shards = num_shards
......@@ -3960,6 +3969,7 @@ class Cifar10Dataset(MappableDataset):
def get_args(self):
args = super().get_args()
args["dataset_dir"] = self.dataset_dir
args["usage"] = self.usage
args["num_samples"] = self.num_samples
args["sampler"] = self.sampler
args["num_shards"] = self.num_shards
......@@ -3975,7 +3985,7 @@ class Cifar10Dataset(MappableDataset):
Number, number of batches.
"""
if self.dataset_size is None:
num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
num_rows = CifarOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage, True)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
......@@ -4051,6 +4061,9 @@ class Cifar100Dataset(MappableDataset):
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Usage of this dataset, can be "train", "test" or "all" . "train" will read from 50,000
train samples, "test" will read from 10,000 test samples, "all" will read from all 60,000 samples.
(default=None, all samples)
num_samples (int, optional): The number of images to be included in the dataset.
(default=None, all images).
num_parallel_workers (int, optional): Number of workers to read the data
......@@ -4082,11 +4095,12 @@ class Cifar100Dataset(MappableDataset):
"""
@check_mnist_cifar_dataset
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
shuffle=None, sampler=None, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.usage = usage
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
self.num_shards = num_shards
......@@ -4096,6 +4110,7 @@ class Cifar100Dataset(MappableDataset):
def get_args(self):
args = super().get_args()
args["dataset_dir"] = self.dataset_dir
args["usage"] = self.usage
args["num_samples"] = self.num_samples
args["sampler"] = self.sampler
args["num_shards"] = self.num_shards
......@@ -4111,7 +4126,7 @@ class Cifar100Dataset(MappableDataset):
Number, number of batches.
"""
if self.dataset_size is None:
num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
num_rows = CifarOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage, False)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
......@@ -4467,7 +4482,7 @@ class VOCDataset(MappableDataset):
dataset_dir (str): Path to the root directory that contains the dataset.
task (str): Set the task type of reading voc data, now only support "Segmentation" or "Detection"
(default="Segmentation").
mode (str): Set the data list txt file to be readed (default="train").
usage (str): The type of data list text file to be read (default="train").
class_indexing (dict, optional): A str-to-int mapping from label name to index, only valid in
"Detection" task (default=None, the folder names will be sorted alphabetically and each
class will be given a unique index starting from 0).
......@@ -4502,24 +4517,24 @@ class VOCDataset(MappableDataset):
>>> import mindspore.dataset as ds
>>> dataset_dir = "/path/to/voc_dataset_directory"
>>> # 1) read VOC data for segmenatation train
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", mode="train")
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", usage="train")
>>> # 2) read VOC data for detection train
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train")
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train")
>>> # 3) read all VOC dataset samples in dataset_dir with 8 threads in random order:
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", num_parallel_workers=8)
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train", num_parallel_workers=8)
>>> # 4) read then decode all VOC dataset samples in dataset_dir in sequence:
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", decode=True, shuffle=False)
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train", decode=True, shuffle=False)
>>> # in VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target"
>>> # in VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation"
"""
@check_vocdataset
def __init__(self, dataset_dir, task="Segmentation", mode="train", class_indexing=None, num_samples=None,
def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None,
num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.task = task
self.mode = mode
self.usage = usage
self.class_indexing = class_indexing
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
......@@ -4532,7 +4547,7 @@ class VOCDataset(MappableDataset):
args = super().get_args()
args["dataset_dir"] = self.dataset_dir
args["task"] = self.task
args["mode"] = self.mode
args["usage"] = self.usage
args["class_indexing"] = self.class_indexing
args["num_samples"] = self.num_samples
args["sampler"] = self.sampler
......@@ -4560,7 +4575,7 @@ class VOCDataset(MappableDataset):
else:
class_indexing = self.class_indexing
num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.usage, class_indexing, num_samples)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
......@@ -4584,7 +4599,7 @@ class VOCDataset(MappableDataset):
else:
class_indexing = self.class_indexing
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing)
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.usage, class_indexing)
def is_shuffled(self):
if self.shuffle_level is None:
......@@ -4824,7 +4839,7 @@ class CelebADataset(MappableDataset):
dataset_dir (str): Path to the root directory that contains the dataset.
num_parallel_workers (int, optional): Number of workers to read the data (default=value set in the config).
shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None).
dataset_type (str): one of 'all', 'train', 'valid' or 'test'.
usage (str): one of 'all', 'train', 'valid' or 'test'.
sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
decode (bool, optional): decode the images after reading (default=False).
extensions (list[str], optional): List of file extensions to be
......@@ -4838,8 +4853,8 @@ class CelebADataset(MappableDataset):
"""
@check_celebadataset
def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, dataset_type='all',
sampler=None, decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None):
def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False,
extensions=None, num_samples=None, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
......@@ -4847,7 +4862,7 @@ class CelebADataset(MappableDataset):
self.decode = decode
self.extensions = extensions
self.num_samples = num_samples
self.dataset_type = dataset_type
self.usage = usage
self.num_shards = num_shards
self.shard_id = shard_id
self.shuffle_level = shuffle
......@@ -4860,7 +4875,7 @@ class CelebADataset(MappableDataset):
args["decode"] = self.decode
args["extensions"] = self.extensions
args["num_samples"] = self.num_samples
args["dataset_type"] = self.dataset_type
args["usage"] = self.usage
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
return args
......
......@@ -273,7 +273,7 @@ def create_node(node):
elif dataset_op == 'MnistDataset':
sampler = construct_sampler(node.get('sampler'))
pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
elif dataset_op == 'MindDataset':
......@@ -296,12 +296,12 @@ def create_node(node):
elif dataset_op == 'Cifar10Dataset':
sampler = construct_sampler(node.get('sampler'))
pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
elif dataset_op == 'Cifar100Dataset':
sampler = construct_sampler(node.get('sampler'))
pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
elif dataset_op == 'VOCDataset':
......
......@@ -27,7 +27,7 @@ from mindspore.dataset.callback import DSCallback
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
check_columns, check_pos_int32
check_columns, check_pos_int32, check_valid_str
from . import datasets
from . import samplers
......@@ -74,6 +74,10 @@ def check_mnist_cifar_dataset(method):
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "test", "all"], "usage")
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
......@@ -154,15 +158,15 @@ def check_vocdataset(method):
task = param_dict.get('task')
type_check(task, (str,), "task")
mode = param_dict.get('mode')
type_check(mode, (str,), "mode")
usage = param_dict.get('usage')
type_check(usage, (str,), "usage")
if task == "Segmentation":
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", mode + ".txt")
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
if param_dict.get('class_indexing') is not None:
raise ValueError("class_indexing is invalid in Segmentation task")
elif task == "Detection":
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", mode + ".txt")
imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
else:
raise ValueError("Invalid task : " + task)
......@@ -235,9 +239,9 @@ def check_celebadataset(method):
validate_dataset_param_value(nreq_param_list, param_dict, list)
validate_dataset_param_value(nreq_param_str, param_dict, str)
dataset_type = param_dict.get('dataset_type')
if dataset_type is not None and dataset_type not in ('all', 'train', 'valid', 'test'):
raise ValueError("dataset_type should be one of 'all', 'train', 'valid' or 'test'.")
usage = param_dict.get('usage')
if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
raise ValueError("usage should be one of 'all', 'train', 'valid' or 'test'.")
check_sampler_shuffle_shard_options(param_dict)
......
......@@ -5,7 +5,7 @@ SET(DE_UT_SRCS
common/cvop_common.cc
common/bboxop_common.cc
auto_contrast_op_test.cc
album_op_test.cc
album_op_test.cc
batch_op_test.cc
bit_functions_test.cc
storage_container_test.cc
......@@ -62,8 +62,8 @@ SET(DE_UT_SRCS
rescale_op_test.cc
resize_op_test.cc
resize_with_bbox_op_test.cc
rgba_to_bgr_op_test.cc
rgba_to_rgb_op_test.cc
rgba_to_bgr_op_test.cc
rgba_to_rgb_op_test.cc
schema_test.cc
skip_op_test.cc
shuffle_op_test.cc
......
......@@ -28,7 +28,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
......@@ -45,10 +45,10 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 10);
......@@ -62,7 +62,7 @@ TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
// Create a Cifar100 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
std::shared_ptr<Dataset> ds = Cifar100(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
......@@ -96,7 +96,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetFail1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100DatasetFail1.";
// Create a Cifar100 Dataset
std::shared_ptr<Dataset> ds = Cifar100("", RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar100("", std::string(), RandomSampler(false, 10));
EXPECT_EQ(ds, nullptr);
}
......@@ -104,7 +104,7 @@ TEST_F(MindDataTestPipeline, TestCifar10DatasetFail1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10DatasetFail1.";
// Create a Cifar10 Dataset
std::shared_ptr<Dataset> ds = Cifar10("", RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10("", std::string(), RandomSampler(false, 10));
EXPECT_EQ(ds, nullptr);
}
......@@ -113,7 +113,7 @@ TEST_F(MindDataTestPipeline, TestCifar10DatasetWithNullSampler) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, nullptr);
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), nullptr);
// Expect failure: sampler can not be nullptr
EXPECT_EQ(ds, nullptr);
}
......@@ -123,7 +123,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetWithNullSampler) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
std::shared_ptr<Dataset> ds = Cifar100(folder_path, nullptr);
std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), nullptr);
// Expect failure: sampler can not be nullptr
EXPECT_EQ(ds, nullptr);
}
......@@ -133,7 +133,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetWithWrongSampler) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
std::shared_ptr<Dataset> ds = Cifar100(folder_path, RandomSampler(false, -10));
std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), RandomSampler(false, -10));
// Expect failure: sampler is not construnced correctly
EXPECT_EQ(ds, nullptr);
}
......@@ -28,7 +28,7 @@ TEST_F(MindDataTestPipeline, TestIteratorEmptyColumn) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorEmptyColumn.";
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 5));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 5));
EXPECT_NE(ds, nullptr);
// Create a Rename operation on ds
......@@ -64,7 +64,7 @@ TEST_F(MindDataTestPipeline, TestIteratorOneColumn) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn.";
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 4));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 4));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
......@@ -103,7 +103,7 @@ TEST_F(MindDataTestPipeline, TestIteratorReOrder) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorReOrder.";
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, SequentialSampler(false, 4));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), SequentialSampler(false, 4));
EXPECT_NE(ds, nullptr);
// Create a Take operation on ds
......@@ -160,9 +160,8 @@ TEST_F(MindDataTestPipeline, TestIteratorTwoColumns) {
// Iterate the dataset and get each row
std::vector<std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
std::vector<TensorShape> expect = {TensorShape({173673}), TensorShape({1, 4}),
TensorShape({173673}), TensorShape({1, 4}),
TensorShape({147025}), TensorShape({1, 4}),
std::vector<TensorShape> expect = {TensorShape({173673}), TensorShape({1, 4}), TensorShape({173673}),
TensorShape({1, 4}), TensorShape({147025}), TensorShape({1, 4}),
TensorShape({211653}), TensorShape({1, 4})};
uint64_t i = 0;
......@@ -187,7 +186,7 @@ TEST_F(MindDataTestPipeline, TestIteratorWrongColumn) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn.";
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 4));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 4));
EXPECT_NE(ds, nullptr);
// Pass wrong column name
......
......@@ -40,7 +40,7 @@ TEST_F(MindDataTestPipeline, TestBatchAndRepeat) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
......@@ -82,7 +82,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess1) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a BucketBatchByLength operation on ds
......@@ -118,13 +118,12 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a BucketBatchByLength operation on ds
std::map<std::string, std::pair<mindspore::dataset::TensorShape, std::shared_ptr<Tensor>>> pad_info;
ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, 2, 3},
&BucketBatchTestFunction, pad_info, true, true);
ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, 2, 3}, &BucketBatchTestFunction, pad_info, true, true);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
......@@ -157,7 +156,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail1) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a BucketBatchByLength operation on ds
......@@ -172,7 +171,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail2) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a BucketBatchByLength operation on ds
......@@ -187,7 +186,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail3) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a BucketBatchByLength operation on ds
......@@ -202,7 +201,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail4) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a BucketBatchByLength operation on ds
......@@ -217,7 +216,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail5) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a BucketBatchByLength operation on ds
......@@ -232,7 +231,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail6) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a BucketBatchByLength operation on ds
ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, -2, 3});
......@@ -246,7 +245,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail7) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a BucketBatchByLength operation on ds
......@@ -313,7 +312,7 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess) {
// Create a Cifar10 Dataset
// Column names: {"image", "label"}
folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 9));
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 9));
EXPECT_NE(ds2, nullptr);
// Create a Project operation on ds
......@@ -365,7 +364,7 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess2) {
// Create a Cifar10 Dataset
// Column names: {"image", "label"}
folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 9));
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 9));
EXPECT_NE(ds2, nullptr);
// Create a Project operation on ds
......@@ -704,11 +703,11 @@ TEST_F(MindDataTestPipeline, TestRenameSuccess) {
}
TEST_F(MindDataTestPipeline, TestRepeatDefault) {
MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatDefault.";
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatDefault.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr <Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
......@@ -723,21 +722,21 @@ TEST_F(MindDataTestPipeline, TestRepeatDefault) {
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr <Iterator> iter = ds->CreateIterator();
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// iterate over the dataset and get each row
std::unordered_map <std::string, std::shared_ptr<Tensor>> row;
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size()!= 0) {
while (row.size() != 0) {
// manually stop
if (i == 100) {
break;
}
i++;
auto image = row["image"];
MS_LOG(INFO)<< "Tensor image shape: " << image->shape();
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
......@@ -747,11 +746,11 @@ TEST_F(MindDataTestPipeline, TestRepeatDefault) {
}
TEST_F(MindDataTestPipeline, TestRepeatOne) {
MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatOne.";
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatOne.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr <Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
......@@ -766,17 +765,17 @@ TEST_F(MindDataTestPipeline, TestRepeatOne) {
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr <Iterator> iter = ds->CreateIterator();
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// iterate over the dataset and get each row
std::unordered_map <std::string, std::shared_ptr<Tensor>> row;
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size()!= 0) {
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO)<< "Tensor image shape: " << image->shape();
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
......@@ -1013,7 +1012,7 @@ TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 20));
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 20));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
......@@ -1060,7 +1059,6 @@ TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestZipFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipFail.";
// We expect this test to fail because we are the both datasets we are zipping have "image" and "label" columns
......@@ -1128,7 +1126,7 @@ TEST_F(MindDataTestPipeline, TestZipSuccess) {
EXPECT_NE(ds1, nullptr);
folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds2, nullptr);
// Create a Project operation on ds
......
......@@ -43,10 +43,11 @@ TEST_F(MindDataTestPipeline, TestCelebADataset) {
// Check if CelebAOp read correct images/attr
std::string expect_file[] = {"1.JPEG", "2.jpg"};
std::vector<std::vector<uint32_t>> expect_attr_vector =
{{0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0,
1, 0, 0, 1}, {0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 1}};
std::vector<std::vector<uint32_t>> expect_attr_vector = {
{0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1},
{0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1}};
uint64_t i = 0;
while (row.size() != 0) {
auto image = row["image"];
......@@ -132,7 +133,7 @@ TEST_F(MindDataTestPipeline, TestMnistFailWithWrongDatasetDir) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistFailWithWrongDatasetDir.";
// Create a Mnist Dataset
std::shared_ptr<Dataset> ds = Mnist("", RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Mnist("", std::string(), RandomSampler(false, 10));
EXPECT_EQ(ds, nullptr);
}
......@@ -141,7 +142,7 @@ TEST_F(MindDataTestPipeline, TestMnistFailWithNullSampler) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, nullptr);
std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), nullptr);
// Expect failure: sampler can not be nullptr
EXPECT_EQ(ds, nullptr);
}
......
......@@ -30,7 +30,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
int number_of_classes = 10;
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
......@@ -38,7 +38,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
EXPECT_NE(hwc_to_chw, nullptr);
// Create a Map operation on ds
ds = ds->Map({hwc_to_chw},{"image"});
ds = ds->Map({hwc_to_chw}, {"image"});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
......@@ -51,10 +51,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"});
ds = ds->Map({one_hot_op}, {"label"});
EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0);
std::shared_ptr<TensorOperation> cutmix_batch_op =
vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0);
EXPECT_NE(cutmix_batch_op, nullptr);
// Create a Map operation on ds
......@@ -77,10 +78,12 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
auto label = row["label"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
MS_LOG(INFO) << "Label shape: " << label->shape();
EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 3 == image->shape()[1]
&& 32 == image->shape()[2] && 32 == image->shape()[3], true);
EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 3 == image->shape()[1] &&
32 == image->shape()[2] && 32 == image->shape()[3],
true);
EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] &&
number_of_classes == label->shape()[1], true);
number_of_classes == label->shape()[1],
true);
iter->GetNextRow(&row);
}
......@@ -95,7 +98,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
int number_of_classes = 10;
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
......@@ -108,7 +111,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) {
EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"});
ds = ds->Map({one_hot_op}, {"label"});
EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC);
......@@ -134,10 +137,12 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) {
auto label = row["label"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
MS_LOG(INFO) << "Label shape: " << label->shape();
EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 32 == image->shape()[1]
&& 32 == image->shape()[2] && 3 == image->shape()[3], true);
EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 32 == image->shape()[1] &&
32 == image->shape()[2] && 3 == image->shape()[3],
true);
EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] &&
number_of_classes == label->shape()[1], true);
number_of_classes == label->shape()[1],
true);
iter->GetNextRow(&row);
}
......@@ -151,7 +156,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) {
// Must fail because alpha can't be negative
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
......@@ -164,10 +169,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) {
EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"});
ds = ds->Map({one_hot_op}, {"label"});
EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, -1, 0.5);
std::shared_ptr<TensorOperation> cutmix_batch_op =
vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, -1, 0.5);
EXPECT_EQ(cutmix_batch_op, nullptr);
}
......@@ -175,7 +181,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
// Must fail because prob can't be negative
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
......@@ -188,20 +194,19 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) {
EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"});
ds = ds->Map({one_hot_op}, {"label"});
EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC,
1, -0.5);
std::shared_ptr<TensorOperation> cutmix_batch_op =
vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5);
EXPECT_EQ(cutmix_batch_op, nullptr);
}
TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) {
// Must fail because alpha can't be zero
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
......@@ -214,11 +219,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) {
EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"});
ds = ds->Map({one_hot_op}, {"label"});
EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC,
0.0, 0.5);
std::shared_ptr<TensorOperation> cutmix_batch_op =
vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 0.0, 0.5);
EXPECT_EQ(cutmix_batch_op, nullptr);
}
......@@ -371,7 +376,7 @@ TEST_F(MindDataTestPipeline, TestHwcToChw) {
TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
......@@ -395,7 +400,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) {
// This should fail because alpha can't be zero
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
......@@ -418,7 +423,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) {
TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
......@@ -467,7 +472,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
......@@ -871,8 +876,7 @@ TEST_F(MindDataTestPipeline, TestRandomPosterizeSuccess1) {
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> posterize =
vision::RandomPosterize({1, 4});
std::shared_ptr<TensorOperation> posterize = vision::RandomPosterize({1, 4});
EXPECT_NE(posterize, nullptr);
// Create a Map operation on ds
......@@ -1114,7 +1118,7 @@ TEST_F(MindDataTestPipeline, TestRandomRotation) {
TEST_F(MindDataTestPipeline, TestUniformAugWithOps) {
// Create a Mnist Dataset
std::string folder_path = datasets_root_path_ + "/testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 20));
std::shared_ptr<Dataset> ds = Mnist(folder_path, "", RandomSampler(false, 20));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
......
......@@ -42,9 +42,13 @@ std::shared_ptr<CelebAOp> Celeba(int32_t num_workers, int32_t rows_per_buffer, i
bool decode = false, const std::string &dataset_type="all") {
std::shared_ptr<CelebAOp> so;
CelebAOp::Builder builder;
Status rc = builder.SetNumWorkers(num_workers).SetCelebADir(dir).SetRowsPerBuffer(rows_per_buffer)
.SetOpConnectorSize(queue_size).SetSampler(std::move(sampler)).SetDecode(decode)
.SetDatasetType(dataset_type).Build(&so);
Status rc = builder.SetNumWorkers(num_workers)
.SetCelebADir(dir)
.SetRowsPerBuffer(rows_per_buffer)
.SetOpConnectorSize(queue_size)
.SetSampler(std::move(sampler))
.SetDecode(decode)
.SetUsage(dataset_type).Build(&so);
return so;
}
......
......@@ -63,9 +63,7 @@ TEST_F(MindDataTestVOCOp, TestVOCDetection) {
std::string task_mode("train");
std::shared_ptr<VOCOp> my_voc_op;
VOCOp::Builder builder;
Status rc = builder.SetDir(dataset_path)
.SetTask(task_type)
.SetMode(task_mode)
Status rc = builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode)
.Build(&my_voc_op);
ASSERT_TRUE(rc.IsOk());
......@@ -116,9 +114,7 @@ TEST_F(MindDataTestVOCOp, TestVOCSegmentation) {
std::string task_mode("train");
std::shared_ptr<VOCOp> my_voc_op;
VOCOp::Builder builder;
Status rc = builder.SetDir(dataset_path)
.SetTask(task_type)
.SetMode(task_mode)
Status rc = builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode)
.Build(&my_voc_op);
ASSERT_TRUE(rc.IsOk());
......@@ -173,9 +169,8 @@ TEST_F(MindDataTestVOCOp, TestVOCClassIndex) {
class_index["train"] = 5;
std::shared_ptr<VOCOp> my_voc_op;
VOCOp::Builder builder;
Status rc = builder.SetDir(dataset_path)
.SetTask(task_type)
.SetMode(task_mode)
Status rc =
builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode)
.SetClassIndex(class_index)
.Build(&my_voc_op);
ASSERT_TRUE(rc.IsOk());
......
......@@ -42,8 +42,8 @@ def test_bounding_box_augment_with_rotation_op(plot_vis=False):
original_seed = config_get_set_seed(0)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
# Ratio is set to 1 to apply rotation on all bounding boxes.
test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1)
......@@ -81,8 +81,8 @@ def test_bounding_box_augment_with_crop_op(plot_vis=False):
original_seed = config_get_set_seed(0)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
# Ratio is set to 0.9 to apply RandomCrop of size (50, 50) on 90% of the bounding boxes.
test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(50), 0.9)
......@@ -120,8 +120,8 @@ def test_bounding_box_augment_valid_ratio_c(plot_vis=False):
original_seed = config_get_set_seed(1)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9)
......@@ -188,8 +188,8 @@ def test_bounding_box_augment_valid_edge_c(plot_vis=False):
original_seed = config_get_set_seed(1)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1)
......@@ -232,7 +232,7 @@ def test_bounding_box_augment_invalid_ratio_c():
"""
logger.info("test_bounding_box_augment_invalid_ratio_c")
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
try:
# ratio range is from 0 - 1
......@@ -256,13 +256,13 @@ def test_bounding_box_augment_invalid_bounds_c():
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1),
1)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features")
......
......@@ -20,7 +20,7 @@ DATA_DIR = "../data/dataset/testCelebAData/"
def test_celeba_dataset_label():
data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
expect_labels = [
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
0, 0, 1],
......@@ -85,11 +85,13 @@ def test_celeba_dataset_distribute():
count = count + 1
assert count == 1
def test_celeba_get_dataset_size():
data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
size = data.get_dataset_size()
assert size == 2
if __name__ == '__main__':
test_celeba_dataset_label()
test_celeba_dataset_op()
......
......@@ -392,6 +392,59 @@ def test_cifar100_visualize(plot=False):
visualize_dataset(image_list, label_list)
def test_cifar_usage():
"""
test usage of cifar
"""
logger.info("Test Cifar100Dataset usage flag")
# flag, if True, test cifar10 else test cifar100
def test_config(usage, flag=True, cifar_path=None):
if cifar_path is None:
cifar_path = DATA_DIR_10 if flag else DATA_DIR_100
try:
data = ds.Cifar10Dataset(cifar_path, usage=usage) if flag else ds.Cifar100Dataset(cifar_path, usage=usage)
num_rows = 0
for _ in data.create_dict_iterator():
num_rows += 1
except (ValueError, TypeError, RuntimeError) as e:
return str(e)
return num_rows
# test the usage of CIFAR100
assert test_config("train") == 10000
assert test_config("all") == 10000
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
assert "Argument usage with value ['list'] is not of type (<class 'str'>,)" in test_config(["list"])
assert "no valid data matching the dataset API Cifar10Dataset" in test_config("test")
# test the usage of CIFAR10
assert test_config("test", False) == 10000
assert test_config("all", False) == 10000
assert "no valid data matching the dataset API Cifar100Dataset" in test_config("train", False)
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid", False)
# change this directory to the folder that contains all cifar10 files
all_cifar10 = None
if all_cifar10 is not None:
assert test_config("train", True, all_cifar10) == 50000
assert test_config("test", True, all_cifar10) == 10000
assert test_config("all", True, all_cifar10) == 60000
assert ds.Cifar10Dataset(all_cifar10, usage="train").get_dataset_size() == 50000
assert ds.Cifar10Dataset(all_cifar10, usage="test").get_dataset_size() == 10000
assert ds.Cifar10Dataset(all_cifar10, usage="all").get_dataset_size() == 60000
# change this directory to the folder that contains all cifar100 files
all_cifar100 = None
if all_cifar100 is not None:
assert test_config("train", False, all_cifar100) == 50000
assert test_config("test", False, all_cifar100) == 10000
assert test_config("all", False, all_cifar100) == 60000
assert ds.Cifar100Dataset(all_cifar100, usage="train").get_dataset_size() == 50000
assert ds.Cifar100Dataset(all_cifar100, usage="test").get_dataset_size() == 10000
assert ds.Cifar100Dataset(all_cifar100, usage="all").get_dataset_size() == 60000
if __name__ == '__main__':
test_cifar10_content_check()
test_cifar10_basic()
......@@ -405,3 +458,5 @@ if __name__ == '__main__':
test_cifar100_pk_sampler()
test_cifar100_exception()
test_cifar100_visualize(plot=False)
test_cifar_usage()
......@@ -58,6 +58,14 @@ def test_mnist_dataset_size():
ds_total = ds.MnistDataset(MNIST_DATA_DIR)
assert ds_total.get_dataset_size() == 10000
# test get dataset_size with the usage arg
test_size = ds.MnistDataset(MNIST_DATA_DIR, usage="test").get_dataset_size()
assert test_size == 10000
train_size = ds.MnistDataset(MNIST_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 0
all_size = ds.MnistDataset(MNIST_DATA_DIR, usage="all").get_dataset_size()
assert all_size == 10000
ds_shard_1_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=1, shard_id=0)
assert ds_shard_1_0.get_dataset_size() == 10000
......@@ -86,6 +94,14 @@ def test_cifar10_dataset_size():
ds_total = ds.Cifar10Dataset(CIFAR10_DATA_DIR)
assert ds_total.get_dataset_size() == 10000
# test get_dataset_size with usage flag
train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 10000
test_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="test").get_dataset_size()
assert test_size == 0
all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
assert all_size == 10000
ds_shard_1_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=1, shard_id=0)
assert ds_shard_1_0.get_dataset_size() == 10000
......@@ -103,6 +119,14 @@ def test_cifar100_dataset_size():
ds_total = ds.Cifar100Dataset(CIFAR100_DATA_DIR)
assert ds_total.get_dataset_size() == 10000
# test get_dataset_size with usage flag
train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 0
test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
assert test_size == 10000
all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()
assert all_size == 10000
ds_shard_1_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=1, shard_id=0)
assert ds_shard_1_0.get_dataset_size() == 10000
......@@ -111,3 +135,12 @@ def test_cifar100_dataset_size():
ds_shard_3_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=3, shard_id=0)
assert ds_shard_3_0.get_dataset_size() == 3334
if __name__ == '__main__':
test_imagenet_rawdata_dataset_size()
test_imagenet_tf_file_dataset_size()
test_mnist_dataset_size()
test_manifest_dataset_size()
test_cifar10_dataset_size()
test_cifar100_dataset_size()
......@@ -229,6 +229,41 @@ def test_mnist_visualize(plot=False):
visualize_dataset(image_list, label_list)
def test_mnist_usage():
"""
Validate MnistDataset image readings
"""
logger.info("Test MnistDataset usage flag")
def test_config(usage, mnist_path=None):
mnist_path = DATA_DIR if mnist_path is None else mnist_path
try:
data = ds.MnistDataset(mnist_path, usage=usage, shuffle=False)
num_rows = 0
for _ in data.create_dict_iterator():
num_rows += 1
except (ValueError, TypeError, RuntimeError) as e:
return str(e)
return num_rows
assert test_config("test") == 10000
assert test_config("all") == 10000
assert " no valid data matching the dataset API MnistDataset" in test_config("train")
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
assert "Argument usage with value ['list'] is not of type (<class 'str'>,)" in test_config(["list"])
# change this directory to the folder that contains all mnist files
all_files_path = None
# the following tests on the entire datasets
if all_files_path is not None:
assert test_config("train", all_files_path) == 60000
assert test_config("test", all_files_path) == 10000
assert test_config("all", all_files_path) == 70000
assert ds.MnistDataset(all_files_path, usage="train").get_dataset_size() == 60000
assert ds.MnistDataset(all_files_path, usage="test").get_dataset_size() == 10000
assert ds.MnistDataset(all_files_path, usage="all").get_dataset_size() == 70000
if __name__ == '__main__':
test_mnist_content_check()
test_mnist_basic()
......@@ -236,3 +271,4 @@ if __name__ == '__main__':
test_mnist_sequential_sampler()
test_mnist_exception()
test_mnist_visualize(plot=True)
test_mnist_usage()
......@@ -21,7 +21,7 @@ TARGET_SHAPE = [680, 680, 680, 680, 642, 607, 561, 596, 612, 680]
def test_voc_segmentation():
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
num = 0
for item in data1.create_dict_iterator(num_epochs=1):
assert item["image"].shape[0] == IMAGE_SHAPE[num]
......@@ -31,7 +31,7 @@ def test_voc_segmentation():
def test_voc_detection():
data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
num = 0
count = [0, 0, 0, 0, 0, 0]
for item in data1.create_dict_iterator(num_epochs=1):
......@@ -45,7 +45,7 @@ def test_voc_detection():
def test_voc_class_index():
class_index = {'car': 0, 'cat': 1, 'train': 5}
data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", class_indexing=class_index, decode=True)
data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True)
class_index1 = data1.get_class_indexing()
assert (class_index1 == {'car': 0, 'cat': 1, 'train': 5})
data1 = data1.shuffle(4)
......@@ -63,7 +63,7 @@ def test_voc_class_index():
def test_voc_get_class_indexing():
data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)
data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True)
class_index1 = data1.get_class_indexing()
assert (class_index1 == {'car': 0, 'cat': 1, 'chair': 2, 'dog': 3, 'person': 4, 'train': 5})
data1 = data1.shuffle(4)
......@@ -81,7 +81,7 @@ def test_voc_get_class_indexing():
def test_case_0():
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True)
resize_op = vision.Resize((224, 224))
......@@ -99,7 +99,7 @@ def test_case_0():
def test_case_1():
data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)
data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True)
resize_op = vision.Resize((224, 224))
......@@ -116,7 +116,7 @@ def test_case_1():
def test_case_2():
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True)
sizes = [0.5, 0.5]
randomize = False
dataset1, dataset2 = data1.split(sizes=sizes, randomize=randomize)
......@@ -134,7 +134,7 @@ def test_case_2():
def test_voc_exception():
try:
data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True)
data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", usage="train", decode=True)
for _ in data1.create_dict_iterator(num_epochs=1):
pass
assert False
......@@ -142,7 +142,7 @@ def test_voc_exception():
pass
try:
data2 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", class_indexing={"cat": 0}, decode=True)
data2 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", class_indexing={"cat": 0}, decode=True)
for _ in data2.create_dict_iterator(num_epochs=1):
pass
assert False
......@@ -150,7 +150,7 @@ def test_voc_exception():
pass
try:
data3 = ds.VOCDataset(DATA_DIR, task="Detection", mode="notexist", decode=True)
data3 = ds.VOCDataset(DATA_DIR, task="Detection", usage="notexist", decode=True)
for _ in data3.create_dict_iterator(num_epochs=1):
pass
assert False
......@@ -158,7 +158,7 @@ def test_voc_exception():
pass
try:
data4 = ds.VOCDataset(DATA_DIR, task="Detection", mode="xmlnotexist", decode=True)
data4 = ds.VOCDataset(DATA_DIR, task="Detection", usage="xmlnotexist", decode=True)
for _ in data4.create_dict_iterator(num_epochs=1):
pass
assert False
......@@ -166,7 +166,7 @@ def test_voc_exception():
pass
try:
data5 = ds.VOCDataset(DATA_DIR, task="Detection", mode="invalidxml", decode=True)
data5 = ds.VOCDataset(DATA_DIR, task="Detection", usage="invalidxml", decode=True)
for _ in data5.create_dict_iterator(num_epochs=1):
pass
assert False
......@@ -174,7 +174,7 @@ def test_voc_exception():
pass
try:
data6 = ds.VOCDataset(DATA_DIR, task="Detection", mode="xmlnoobject", decode=True)
data6 = ds.VOCDataset(DATA_DIR, task="Detection", usage="xmlnoobject", decode=True)
for _ in data6.create_dict_iterator(num_epochs=1):
pass
assert False
......
......@@ -35,6 +35,7 @@ def diff_mse(in1, in2):
mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
return mse * 100
def test_cifar10():
"""
dataset parameter
......@@ -45,7 +46,7 @@ def test_cifar10():
batch_size = 32
limit_dataset = 100
# apply dataset operations
data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)
data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
data1 = data1.repeat(num_repeat)
data1 = data1.batch(batch_size, True)
num_epoch = 5
......@@ -139,6 +140,7 @@ def test_generator_dict_0():
np.testing.assert_array_equal(item["data"], golden)
i = i + 1
def test_generator_dict_1():
"""
test generator dict 1
......@@ -158,6 +160,7 @@ def test_generator_dict_1():
i = i + 1
assert i == 64
def test_generator_dict_2():
"""
test generator dict 2
......@@ -180,6 +183,7 @@ def test_generator_dict_2():
assert item1
# rely on garbage collector to destroy iter1
def test_generator_dict_3():
"""
test generator dict 3
......@@ -226,6 +230,7 @@ def test_generator_dict_4():
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
def test_generator_dict_4_1():
"""
test generator dict 4_1
......@@ -249,6 +254,7 @@ def test_generator_dict_4_1():
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
def test_generator_dict_4_2():
"""
test generator dict 4_2
......@@ -274,6 +280,7 @@ def test_generator_dict_4_2():
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
def test_generator_dict_5():
"""
test generator dict 5
......@@ -305,6 +312,7 @@ def test_generator_dict_5():
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
# Test tuple iterator
def test_generator_tuple_0():
......@@ -323,6 +331,7 @@ def test_generator_tuple_0():
np.testing.assert_array_equal(item[0], golden)
i = i + 1
def test_generator_tuple_1():
"""
test generator tuple 1
......@@ -342,6 +351,7 @@ def test_generator_tuple_1():
i = i + 1
assert i == 64
def test_generator_tuple_2():
"""
test generator tuple 2
......@@ -364,6 +374,7 @@ def test_generator_tuple_2():
assert item1
# rely on garbage collector to destroy iter1
def test_generator_tuple_3():
"""
test generator tuple 3
......@@ -442,6 +453,7 @@ def test_generator_tuple_5():
err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
assert err_msg in str(info.value)
# Test with repeat
def test_generator_tuple_repeat_1():
"""
......@@ -536,6 +548,7 @@ def test_generator_tuple_repeat_repeat_2():
iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value)
def test_generator_tuple_repeat_repeat_3():
"""
test generator tuple repeat repeat 3
......
......@@ -149,7 +149,7 @@ def test_get_column_name_to_device():
def test_get_column_name_voc():
data = ds.VOCDataset(VOC_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
data = ds.VOCDataset(VOC_DIR, task="Segmentation", usage="train", decode=True, shuffle=False)
assert data.get_col_names() == ["image", "target"]
......
......@@ -22,7 +22,7 @@ DATA_DIR = "../data/dataset/testVOC2012"
def test_noop_pserver():
os.environ['MS_ROLE'] = 'MS_PSERVER'
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
num = 0
for _ in data1.create_dict_iterator(num_epochs=1):
num += 1
......@@ -32,7 +32,7 @@ def test_noop_pserver():
def test_noop_sched():
os.environ['MS_ROLE'] = 'MS_SCHED'
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
num = 0
for _ in data1.create_dict_iterator(num_epochs=1):
num += 1
......
......@@ -42,8 +42,8 @@ def test_random_resized_crop_with_bbox_op_c(plot_vis=False):
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
......@@ -108,8 +108,8 @@ def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False):
logger.info("test_random_resized_crop_with_bbox_op_edge_c")
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
......@@ -142,7 +142,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c():
logger.info("test_random_resized_crop_with_bbox_op_invalid_c")
# Load dataset, only Augmented Dataset as test will raise ValueError
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
try:
# If input range of scale is not in the order of (min, max), ValueError will be raised.
......@@ -168,7 +168,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c():
"""
logger.info("test_random_resized_crop_with_bbox_op_invalid2_c")
# Load dataset # only loading the to AugDataset as test will fail on this
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
try:
# If input range of ratio is not in the order of (min, max), ValueError will be raised.
......@@ -195,13 +195,13 @@ def test_random_resized_crop_with_bbox_op_bad_c():
logger.info("test_random_resized_crop_with_bbox_op_bad_c")
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
......
......@@ -39,8 +39,8 @@ def test_random_crop_with_bbox_op_c(plot_vis=False):
logger.info("test_random_crop_with_bbox_op_c")
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
# define test OP with values to match existing Op UT
test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])
......@@ -101,8 +101,8 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
# define test OP with values to match existing Op unit - test
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255))
......@@ -138,8 +138,8 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False):
logger.info("test_random_crop_with_bbox_op3_c")
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
# define test OP with values to match existing Op unit - test
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
......@@ -168,8 +168,8 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
logger.info("test_random_crop_with_bbox_op_edge_c")
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
# define test OP with values to match existing Op unit - test
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
......@@ -205,7 +205,7 @@ def test_random_crop_with_bbox_op_invalid_c():
logger.info("test_random_crop_with_bbox_op_invalid_c")
# Load dataset
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
try:
# define test OP with values to match existing Op unit - test
......@@ -231,13 +231,13 @@ def test_random_crop_with_bbox_op_bad_c():
logger.info("test_random_crop_with_bbox_op_bad_c")
test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
......@@ -247,7 +247,7 @@ def test_random_crop_with_bbox_op_bad_padding():
"""
logger.info("test_random_crop_with_bbox_op_invalid_c")
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
try:
test_op = c_vision.RandomCropWithBBox([512, 512], padding=-1)
......
......@@ -37,11 +37,9 @@ def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False):
logger.info("test_random_horizontal_flip_with_bbox_op_c")
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.RandomHorizontalFlipWithBBox(1)
......@@ -102,11 +100,9 @@ def test_random_horizontal_flip_with_bbox_valid_rand_c(plot_vis=False):
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.RandomHorizontalFlipWithBBox(0.6)
......@@ -140,8 +136,8 @@ def test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False):
"""
logger.info("test_horizontal_flip_with_bbox_valid_edge_c")
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.RandomHorizontalFlipWithBBox(1)
......@@ -178,7 +174,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c():
"""
logger.info("test_random_horizontal_bbox_invalid_prob_c")
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
try:
# Note: Valid range of prob should be [0.0, 1.0]
......@@ -201,13 +197,13 @@ def test_random_horizontal_flip_with_bbox_invalid_bounds_c():
test_op = c_vision.RandomHorizontalFlipWithBBox(1)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features")
......
......@@ -39,11 +39,9 @@ def test_random_resize_with_bbox_op_voc_c(plot_vis=False):
original_seed = config_get_set_seed(123)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.RandomResizeWithBBox(100)
......@@ -120,11 +118,9 @@ def test_random_resize_with_bbox_op_edge_c(plot_vis=False):
box has dimensions as the image itself.
"""
logger.info("test_random_resize_with_bbox_op_edge_c")
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.RandomResizeWithBBox(500)
......@@ -197,13 +193,13 @@ def test_random_resize_with_bbox_op_bad_c():
logger.info("test_random_resize_with_bbox_op_bad_c")
test_op = c_vision.RandomResizeWithBBox((400, 300))
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
......
......@@ -37,11 +37,9 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False):
"""
logger.info("test_random_vertical_flip_with_bbox_op_c")
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.RandomVerticalFlipWithBBox(1)
......@@ -102,11 +100,9 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.RandomVerticalFlipWithBBox(0.8)
......@@ -139,11 +135,9 @@ def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False):
applied on dynamically generated edge case, expected to pass
"""
logger.info("test_random_vertical_flip_with_bbox_op_edge_c")
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.RandomVerticalFlipWithBBox(1)
......@@ -174,8 +168,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
Test RandomVerticalFlipWithBBox Op on invalid constructor parameters, expected to raise ValueError
"""
logger.info("test_random_vertical_flip_with_bbox_op_invalid_c")
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
try:
test_op = c_vision.RandomVerticalFlipWithBBox(2)
......@@ -201,13 +194,13 @@ def test_random_vertical_flip_with_bbox_op_bad_c():
logger.info("test_random_vertical_flip_with_bbox_op_bad_c")
test_op = c_vision.RandomVerticalFlipWithBBox(1)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
......
......@@ -39,11 +39,9 @@ def test_resize_with_bbox_op_voc_c(plot_vis=False):
logger.info("test_resize_with_bbox_op_voc_c")
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.ResizeWithBBox(100)
......@@ -110,11 +108,9 @@ def test_resize_with_bbox_op_edge_c(plot_vis=False):
box has dimensions as the image itself.
"""
logger.info("test_resize_with_bbox_op_edge_c")
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
test_op = c_vision.ResizeWithBBox(500)
......@@ -163,13 +159,13 @@ def test_resize_with_bbox_op_bad_c():
logger.info("test_resize_with_bbox_op_bad_c")
test_op = c_vision.ResizeWithBBox((200, 300))
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
......
......@@ -32,6 +32,7 @@ from mindspore.dataset.vision import Inter
def test_imagefolder(remove_json_files=True):
"""
Test simulating resnet50 dataset pipeline.
......@@ -103,7 +104,7 @@ def test_mnist_dataset(remove_json_files=True):
data_dir = "../data/dataset/testMnistData"
ds.config.set_seed(1)
data1 = ds.MnistDataset(data_dir, 100)
data1 = ds.MnistDataset(data_dir, num_samples=100)
one_hot_encode = c.OneHot(10) # num_classes is input argument
data1 = data1.map(input_columns="label", operations=one_hot_encode)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册