Commit c79db93c authored by Eric

Initial commit for album

Added linter fix for album dataset

Added testDataset

Adding signature

Added JsonDataset example API

Example dataset

Resolving format

More fixing

Refactor

Small fix

Added compiling album dataset

Running tests

Added linter fix #1

Passing UT

Added dataset API

Addressing clang

Clang part 2

Fixing pass

Fixed tree check

lint fix

Added lint fix part 2
Parent e06dfaa8
......@@ -393,7 +393,7 @@ build_mindspore()
CMAKE_VERBOSE="--verbose"
fi
cmake --build . --target package ${CMAKE_VERBOSE} -j$THREAD_NUM
echo "success to build mindspore project!"
echo "success building mindspore project!"
}
checkndk() {
......
......@@ -21,6 +21,7 @@
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/engine/dataset_iterator.h"
// Source dataset headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/source/album_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
......@@ -117,6 +118,15 @@ std::shared_ptr<SchemaObj> Schema(const std::string &schema_file) {
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
// (In alphabetical order)
// Function to create an AlbumDataset.
std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<AlbumDataset>(dataset_dir, data_schema, column_names, decode, sampler);
return ds->ValidateParams() ? ds : nullptr;
}
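Like the other factory functions in this file, Album() returns nullptr when parameter validation fails, so callers should check the result. A minimal hedged sketch of the call pattern (the paths are placeholders, not shipped data):
// Sketch only: the factory validates its inputs and yields nullptr on failure.
std::shared_ptr<AlbumDataset> ds =
    Album("/path/to/album/images", "/path/to/datasetSchema.json", {"image", "label", "id"});
if (ds == nullptr) {
  MS_LOG(ERROR) << "Album: invalid dataset directory or schema file";
}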
// Function to create a CelebADataset.
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type,
const std::shared_ptr<SamplerObj> &sampler, bool decode,
......@@ -687,6 +697,49 @@ bool ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_sha
// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
// (In alphabetical order)
// Constructor for AlbumDataset
AlbumDataset::AlbumDataset(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode,
const std::shared_ptr<SamplerObj> &sampler)
: dataset_dir_(dataset_dir),
schema_path_(data_schema),
column_names_(column_names),
decode_(decode),
sampler_(sampler) {}
bool AlbumDataset::ValidateParams() {
if (!ValidateDatasetDirParam("AlbumDataset", dataset_dir_)) {
return false;
}
if (!ValidateDatasetFilesParam("AlbumDataset", {schema_path_})) {
return false;
}
return true;
}
// Function to build AlbumDataset
std::vector<std::shared_ptr<DatasetOp>> AlbumDataset::Build() {
// A vector containing shared pointers to the DatasetOps that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// If the user does not specify a sampler, create a default sampler, i.e., RandomSampler.
if (sampler_ == nullptr) {
sampler_ = CreateDefaultSampler();
}
auto schema = std::make_unique<DataSchema>();
RETURN_EMPTY_IF_ERROR(schema->LoadSchemaFile(schema_path_, column_names_));
// Argument not exposed to the user in the API.
std::set<std::string> extensions = {};
node_ops.push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, extensions, std::move(schema), std::move(sampler_->Build())));
return node_ops;
}
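Since Build() falls back to a RandomSampler when none is supplied, deterministic reads require passing an explicit sampler through the factory. A hedged sketch using the same SequentialSampler call as the pipeline test near the end of this commit (paths are placeholders):
// Sketch only: an explicit sampler overrides the RandomSampler default.
// SequentialSampler(start_index, num_samples), as used in c_api_dataset_album_test.cc.
std::shared_ptr<Dataset> ds =
    Album("/path/to/album/images", "/path/to/datasetSchema.json",
          {"image", "label", "id"}, /*decode=*/true, SequentialSampler(0, 1));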
// Constructor for CelebADataset
CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
......
......@@ -13,6 +13,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
text_file_op.cc
clue_op.cc
csv_op.cc
album_op.cc
)
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_ALBUM_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_ALBUM_OP_H_
#include <deque>
#include <memory>
#include <queue>
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include <utility>
#include <vector>
#include <unordered_map>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
// Forward declares
template <typename T>
class Queue;
// Define row information as a list of file objects to read
using FolderImages = std::shared_ptr<std::pair<std::string, std::queue<std::string>>>;
/// \class AlbumOp album_op.h
class AlbumOp : public ParallelOp, public RandomAccessOp {
public:
class Builder {
public:
/// \brief Constructor for Builder class of AlbumOp
Builder();
/// \brief Destructor.
~Builder() = default;
/// \brief Setter method
/// \param[in] rows_per_buffer
/// \return Builder setter method returns reference to the builder
Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
builder_rows_per_buffer_ = rows_per_buffer;
return *this;
}
/// \brief Setter method
/// \param[in] size
/// \return Builder setter method returns reference to the builder
Builder &SetOpConnectorSize(int32_t size) {
builder_op_connector_size_ = size;
return *this;
}
/// \brief Setter method
/// \param[in] exts - file extensions to be read
/// \return Builder setter method returns reference to the builder
Builder &SetExtensions(const std::set<std::string> &exts) {
builder_extensions_ = exts;
return *this;
}
/// \brief Setter method
/// \param[in] do_decode
/// \return Builder setter method returns reference to the builder
Builder &SetDecode(bool do_decode) {
builder_decode_ = do_decode;
return *this;
}
/// \brief Setter method
/// \param[in] num_workers
/// \return Builder setter method returns reference to the builder
Builder &SetNumWorkers(int32_t num_workers) {
builder_num_workers_ = num_workers;
return *this;
}
/// \brief Setter method
/// \param[in] sampler
/// \return Builder setter method returns reference to the builder
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
/// \brief Setter method
/// \param[in] dir - dataset directory
/// \return Builder setter method returns reference to the builder
Builder &SetAlbumDir(const std::string &dir) {
builder_dir_ = dir;
return *this;
}
/// \brief Setter method
/// \param[in] file - schema file to load
/// \return Builder setter method returns reference to the builder
Builder &SetSchemaFile(const std::string &file) {
builder_schema_file_ = file;
return *this;
}
/// \brief Setter method
/// \param[in] columns - input columns
/// \return Builder setter method returns reference to the builder
Builder &SetColumnsToLoad(const std::vector<std::string> &columns) {
builder_columns_to_load_ = columns;
return *this;
}
/// \brief Check validity of input args
/// \return - The error code return
Status SanityCheck();
/// \brief The builder "build" method creates the final object.
/// \param[inout] std::shared_ptr<AlbumOp> *op - DatasetOp
/// \return - The error code return
Status Build(std::shared_ptr<AlbumOp> *op);
private:
bool builder_decode_;
std::vector<std::string> builder_columns_to_load_;
std::string builder_dir_;
std::string builder_schema_file_;
int32_t builder_num_workers_;
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::set<std::string> builder_extensions_;
std::shared_ptr<Sampler> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
};
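For reference, the unit test near the end of this commit drives this Builder directly; a condensed, hedged sketch of the fluent usage (values are illustrative):
// Sketch only: condensed from album_op_test.cc; paths are placeholders.
std::shared_ptr<AlbumOp> op;
Status rc = AlbumOp::Builder()
                .SetNumWorkers(16)
                .SetAlbumDir("/path/to/testAlbum/images")
                .SetSchemaFile("/path/to/datasetSchema.json")
                .SetColumnsToLoad({"image", "label", "id"})
                .SetRowsPerBuffer(2)
                .SetOpConnectorSize(32)
                .SetExtensions({".json"})
                .Build(&op);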
/// \brief Constructor
/// \param[in] num_wkrs - Num of workers reading images in parallel
/// \param[in] rows_per_buffer Number of images (rows) in each buffer
/// \param[in] file_dir - directory of Album
/// \param[in] queue_size - connector size
/// \param[in] do_decode - decode image files
/// \param[in] exts - set of file extensions to read, if empty, read everything under the dir
/// \param[in] data_schema - schema of dataset
/// \param[in] sampler - sampler tells AlbumOp what to read
AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
/// \brief Destructor.
~AlbumOp() = default;
  /// \brief Initialize AlbumOp-related variables; calls the function to walk all files
/// \return - The error code return
Status PrescanEntry();
  /// \brief Worker thread pulls a number of IOBlocks from the IOBlock queue, makes a buffer and pushes it to Connector
  /// \param[in] worker_id - id of each worker
/// \return Status - The error code return
Status WorkerEntry(int32_t worker_id) override;
/// \brief Main Loop of AlbumOp
/// Master thread: Fill IOBlockQueue, then goes to sleep
  ///     Worker thread: pulls IOBlock from IOBlockQueue, works on it, then puts the buffer into mOutConnector
/// \return Status - The error code return
Status operator()() override;
/// \brief A print method typically used for debugging
/// \param[in] out
/// \param[in] show_all
void Print(std::ostream &out, bool show_all) const override;
  /// \brief Check if image is valid. Only supports JPEG/PNG/GIF/BMP
  ///     This function could be optimized to return the tensor, reducing file open/close overhead
/// \return Status - The error code return
Status CheckImageType(const std::string &file_name, bool *valid);
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "AlbumOp"; }
private:
/// \brief Initialize Sampler, calls sampler->Init() within
/// \return Status The error code return
Status InitSampler();
/// \brief Load image to tensor row
/// \param[in] image_file Image name of file
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code return
Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorRow *row);
/// \brief Load vector of ints to tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing multi-dimensional label
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code return
Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
/// \brief Load string array into a tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing string tensor
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code return
Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
/// \brief Load string into a tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing string tensor
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code return
Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
/// \brief Load float value to tensor row
/// \param[in] json_obj Json object containing float
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code return
Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
  /// \brief Load empty tensor to tensor row
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code return
Status LoadEmptyTensor(uint32_t col_num, TensorRow *row);
/// \brief Load id from file name to tensor row
/// \param[in] file The file name to get ID from
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code return
Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorRow *row);
/// \brief Load a tensor row according to a json file
  /// \param[in] file Json file location
  /// \param[inout] row Tensor row that the Json content is stored into
/// \return Status The error code return
Status LoadTensorRow(const std::string &file, TensorRow *row);
  /// \brief Load a buffer according to the keys in an ioblock
  /// \param[in] keys Keys in ioblock
  /// \param[inout] db Databuffer to push to
/// \return Status The error code return
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
  /// \brief Launches the worker threads and initializes the op; called at the start of operator()()
/// \return The error code return
Status LaunchThreadsAndInitOp();
/// \brief reset Op
/// \return Status The error code return
Status Reset() override;
// Private function for computing the assignment of the column name map.
// @return - Status
Status ComputeColMap() override;
int32_t rows_per_buffer_;
std::string folder_path_; // directory of image folder
bool decode_;
std::set<std::string> extensions_; // extensions allowed
std::unordered_map<std::string, int32_t> col_name_map_;
std::unique_ptr<DataSchema> data_schema_;
std::shared_ptr<Sampler> sampler_;
int64_t row_cnt_;
int64_t buf_cnt_;
int64_t sampler_ind_;
int64_t dirname_offset_;
WaitPost wp_;
std::vector<std::string> image_rows_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_ALBUM_OP_H_
......@@ -134,7 +134,6 @@ Status ImageFolderOp::operator()() {
TensorRow sample_row;
RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
if (sample_ids->type() != DataType(DataType::DE_INT64)) RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64");
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
keys.push_back(*itr);
......
......@@ -30,6 +30,7 @@
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/source/album_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
......@@ -199,6 +200,11 @@ Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified)
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
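Any pass that needs AlbumOp-specific behavior can now override this visitor instead of taking the DatasetOp fallback; a minimal hedged sketch (the pass name is hypothetical):
// Sketch only: a hypothetical pass overriding the new AlbumOp visitor.
class AlbumCounterPass : public NodePass {
 public:
  Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override {
    *modified = false;
    ++num_album_ops_;  // count Album leaves seen during the tree walk
    return Status::OK();
  }

 private:
  int32_t num_album_ops_ = 0;
};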
Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
......
......@@ -49,6 +49,8 @@ class FilterOp;
class GeneratorOp;
#endif
class AlbumOp;
class RandomDataOp;
class RepeatOp;
......@@ -178,6 +180,8 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
......
......@@ -21,6 +21,7 @@
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/source/album_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
......@@ -152,6 +153,11 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> n
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
......
......@@ -79,6 +79,12 @@ class CacheTransformPass : public TreePass {
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identification
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
......
......@@ -111,5 +111,11 @@ Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modifie
std::cout << "Visiting ImageFolderOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting ImageFolderOp" << '\n';
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
......@@ -58,6 +58,8 @@ class PrinterPass : public NodePass {
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;
};
} // namespace dataset
......
......@@ -48,6 +48,7 @@ class TensorOperation;
class SchemaObj;
class SamplerObj;
// Datasets classes (in alphabetical order)
class AlbumDataset;
class CelebADataset;
class Cifar10Dataset;
class Cifar100Dataset;
......@@ -79,13 +80,27 @@ class ZipDataset;
/// \return Shared pointer to the current schema
std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "");
/// \brief Function to create an AlbumDataset
/// \notes The columns of the generated dataset are defined by the given schema
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] data_schema Path to dataset schema file
/// \param[in] column_names Column names used to specify the columns to load; if empty, all columns
///     will be read (default = {})
/// \param[in] decode The option to decode the images in the dataset (default = false)
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = nullptr)
/// \return Shared pointer to the current Dataset
std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names = {}, bool decode = false,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
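A hedged end-to-end sketch of this API, mirroring the TestAlbumBasic pipeline test added by this commit (paths are placeholders):
// Sketch only: iterate an Album dataset through the C++ API.
std::shared_ptr<Dataset> ds =
    Album("/path/to/testAlbum/images", "/path/to/datasetSchema.json", {"image", "label", "id"});
std::shared_ptr<Iterator> iter = ds->CreateIterator();
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
while (!row.empty()) {
  MS_LOG(INFO) << "image shape: " << row["image"]->shape();
  iter->GetNextRow(&row);
}
iter->Stop();  // manually terminate the pipeline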
/// \brief Function to create a CelebADataset
/// \notes The generated dataset has two columns ['image', 'attr'].
// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test'.
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// will be used to randomly iterate the entire dataset
/// \param[in] decode Decode the images after reading (default=false).
/// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
/// \return Shared pointer to the current Dataset
......@@ -97,7 +112,7 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
/// \notes The generated dataset has two columns ['image', 'label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
......@@ -106,7 +121,7 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
......@@ -114,19 +129,19 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
/// \brief Function to create a CLUEDataset
/// \notes The generated dataset has a variable number of columns depending on the task and usage
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
/// will be sorted in a lexicographical order.
/// will be sorted in a lexicographical order.
/// \param[in] task The kind of task, one of "AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC" and "CSL" (default="AFQMC").
/// \param[in] usage Be used to "train", "test" or "eval" data (default="train").
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal)
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current CLUEDataset
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
const std::string &usage = "train", int64_t num_samples = 0,
......@@ -135,19 +150,19 @@ std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files,
/// \brief Function to create a CocoDataset
/// \notes The generated dataset has multi-columns :
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
/// ['iscrowd', dtype=uint32]].
/// - task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd', dtype=uint32]].
/// - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32],
/// ['num_keypoints', dtype=uint32]].
/// - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
///                  ['iscrowd', dtype=uint32], ['area', dtype=uint32]].
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
/// ['iscrowd', dtype=uint32]].
/// - task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd', dtype=uint32]].
/// - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32],
/// ['num_keypoints', dtype=uint32]].
/// - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
///                  ['iscrowd', dtype=uint32], ['area', dtype=uint32]].
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] annotation_file Path to the annotation json
/// \param[in] task Set the task type of reading coco data, now support 'Detection'/'Stuff'/'Panoptic'/'Keypoint'
/// \param[in] decode Decode the images after reading
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task = "Detection", const bool &decode = false,
......@@ -181,12 +196,12 @@ std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, c
/// \brief Function to create an ImageFolderDataset
/// \notes A source dataset that reads images from a tree of directories
/// All images within one folder have the same label
/// The generated dataset has two columns ['image', 'label']
/// All images within one folder have the same label
/// The generated dataset has two columns ['image', 'label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] decode A flag to decode in ImageFolder
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
/// A `RandomSampler` will be used to randomly iterate the entire dataset
/// A `RandomSampler` will be used to randomly iterate the entire dataset
/// \param[in] extensions File extensions to be read
/// \param[in] class_indexing a class name to label map
/// \return Shared pointer to the current ImageFolderDataset
......@@ -200,9 +215,9 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir,
/// \param[in] dataset_file The dataset file to be read
/// \param[in] usage Need "train", "eval" or "inference" data (default="train")
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
/// A `RandomSampler` will be used to randomly iterate the entire dataset
/// A `RandomSampler` will be used to randomly iterate the entire dataset
/// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder
/// names will be sorted alphabetically and each class will be given a unique index starting from 0).
/// names will be sorted alphabetically and each class will be given a unique index starting from 0).
/// \param[in] decode Decode the images after reading (default=false).
/// \return Shared pointer to the current ManifestDataset
std::shared_ptr<ManifestDataset> Manifest(std::string dataset_file, std::string usage = "train",
......@@ -214,7 +229,7 @@ std::shared_ptr<ManifestDataset> Manifest(std::string dataset_file, std::string
/// \notes The generated dataset has two columns ['image', 'label']
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
/// A `RandomSampler` will be used to randomly iterate the entire dataset
/// A `RandomSampler` will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current MnistDataset
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
......@@ -245,17 +260,17 @@ std::shared_ptr<RandomDataset> RandomData(const int32_t &total_rows = 0, T schem
/// \brief Function to create a TextFileDataset
/// \notes The generated dataset has one column ['text']
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
/// will be sorted in a lexicographical order.
/// will be sorted in a lexicographical order.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal)
/// Can be any of:
/// ShuffleMode.kFalse - No shuffling is performed.
/// ShuffleMode.kFiles - Shuffle files only.
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current TextFileDataset
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
......@@ -263,16 +278,16 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase
/// \brief Function to create a VOCDataset
/// \notes The generated dataset has multi-columns :
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
/// ['difficult', dtype=uint32], ['truncate', dtype=uint32]].
/// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
/// ['difficult', dtype=uint32], ['truncate', dtype=uint32]].
/// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
/// \param[in] dataset_dir Path to the root directory that contains the dataset
/// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection"
/// \param[in] mode Set the data list txt file to be read
/// \param[in] class_indexing A str-to-int mapping from label name to index
/// \param[in] decode Decode the images after reading
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
const std::string &mode = "train",
......@@ -335,9 +350,9 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \notes Combines batch_size number of consecutive rows into batches
/// \param[in] batch_size The number of rows each batch is created with
/// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
/// batch. If true, and if there are less than batch_size rows
/// available to make the last batch, then those rows will
/// be dropped and not propagated to the next node
/// batch. If true, and if there are less than batch_size rows
/// available to make the last batch, then those rows will
/// be dropped and not propagated to the next node
/// \return Shared pointer to the current BatchDataset
std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
......@@ -368,16 +383,16 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \brief Function to create a MapDataset
/// \notes Applies each operation in operations to this dataset
/// \param[in] operations Vector of operations to be applied on the dataset. Operations are
/// applied in the order they appear in this list
/// applied in the order they appear in this list
/// \param[in] input_columns Vector of the names of the columns that will be passed to the first
/// operation as input. The size of this list must match the number of
/// input columns expected by the first operator. The default input_columns
/// is the first column
/// operation as input. The size of this list must match the number of
/// input columns expected by the first operator. The default input_columns
/// is the first column
/// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
/// This parameter is mandatory if len(input_columns) != len(output_columns)
/// The size of this list must match the number of output columns of the
/// last operation. The default output_columns will have the same
/// name as the input columns, i.e., the columns will be replaced
/// This parameter is mandatory if len(input_columns) != len(output_columns)
/// The size of this list must match the number of output columns of the
/// last operation. The default output_columns will have the same
/// name as the input columns, i.e., the columns will be replaced
/// \param[in] project_columns A list of column names to project
/// \return Shared pointer to the current MapDataset
std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorOperation>> operations,
......@@ -404,7 +419,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \param[in] count Number of times the dataset should be repeated
/// \return Shared pointer to the current Dataset
/// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
/// due to a limitation in the current implementation
/// due to a limitation in the current implementation
std::shared_ptr<Dataset> Repeat(int32_t count = -1);
/// \brief Function to create a Shuffle Dataset
......@@ -506,6 +521,31 @@ class SchemaObj {
// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
// (In alphabetical order)
class AlbumDataset : public Dataset {
public:
/// \brief Constructor
AlbumDataset(const std::string &dataset_dir, const std::string &data_schema,
const std::vector<std::string> &column_names, bool decode, const std::shared_ptr<SamplerObj> &sampler);
/// \brief Destructor
~AlbumDataset() = default;
/// \brief A base class override function to create runtime dataset op objects from this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool ValidateParams() override;
private:
std::string dataset_dir_;
std::string schema_path_;
std::vector<std::string> column_names_;
bool decode_;
std::shared_ptr<SamplerObj> sampler_;
};
class CelebADataset : public Dataset {
public:
/// \brief Constructor
......
......@@ -5,6 +5,7 @@ SET(DE_UT_SRCS
common/cvop_common.cc
common/bboxop_common.cc
auto_contrast_op_test.cc
album_op_test.cc
batch_op_test.cc
bit_functions_test.cc
storage_container_test.cc
......@@ -101,6 +102,7 @@ SET(DE_UT_SRCS
c_api_samplers_test.cc
c_api_transforms_test.cc
c_api_dataset_ops_test.cc
c_api_dataset_album_test.cc
c_api_dataset_cifar_test.cc
c_api_dataset_clue_test.cc
c_api_dataset_coco_test.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include "common/common.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/datasetops/source/album_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "securec.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/transforms.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::ERROR;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
std::shared_ptr<BatchOp> Batch(int batch_size = 1, bool drop = false, int rows_per_buf = 2);
std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
std::shared_ptr<AlbumOp> Album(int64_t num_works, int64_t rows, int64_t conns, std::string path,
bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
bool decode = false) {
std::shared_ptr<AlbumOp> so;
AlbumOp::Builder builder;
Status rc = builder.SetNumWorkers(num_works)
.SetAlbumDir(path)
.SetRowsPerBuffer(rows)
.SetOpConnectorSize(conns)
.SetExtensions({".json"})
.SetSampler(std::move(sampler))
.SetDecode(decode)
.Build(&so);
return so;
}
std::shared_ptr<AlbumOp> AlbumSchema(int64_t num_works, int64_t rows, int64_t conns, std::string path,
std::string schema_file, std::vector<std::string> column_names = {},
bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
bool decode = false) {
std::shared_ptr<AlbumOp> so;
AlbumOp::Builder builder;
Status rc = builder.SetNumWorkers(num_works)
.SetSchemaFile(schema_file)
.SetColumnsToLoad(column_names)
.SetAlbumDir(path)
.SetRowsPerBuffer(rows)
.SetOpConnectorSize(conns)
.SetExtensions({".json"})
.SetSampler(std::move(sampler))
.SetDecode(decode)
.Build(&so);
return so;
}
class MindDataTestAlbum : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchema) {
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
std::vector<std::string> column_names = {"image", "label", "id"};
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file, column_names, false), Repeat(2)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
MS_LOG(ERROR) << "Return code error detected during tree launch: " << ".";
EXPECT_TRUE(false);
} else {
DatasetIterator di(tree);
TensorMap tensor_map;
di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
uint64_t i = 0;
int32_t label = 0;
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "label shape"
<< tensor_map["label"] << "\n";
i++;
di.GetNextAsMap(&tensor_map);
}
MS_LOG(INFO) << "got rows" << i << "\n";
EXPECT_TRUE(i == 14);
}
}
TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchemaNoOrder) {
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
MS_LOG(ERROR) << "Return code error detected during tree launch: " << ".";
EXPECT_TRUE(false);
} else {
DatasetIterator di(tree);
TensorMap tensor_map;
di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
uint64_t i = 0;
int32_t label = 0;
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "label shape"
<< tensor_map["label"] << "\n";
i++;
di.GetNextAsMap(&tensor_map);
}
MS_LOG(INFO) << "got rows" << i << "\n";
EXPECT_TRUE(i == 14);
}
}
TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchemaFloat) {
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
// add the priority column
std::string schema_file = datasets_root_path_ + "/testAlbum/floatSchema.json";
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
MS_LOG(ERROR) << "Return code error detected during tree launch: " << ".";
EXPECT_TRUE(false);
} else {
DatasetIterator di(tree);
TensorMap tensor_map;
di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
uint64_t i = 0;
int32_t label = 0;
double priority = 0;
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
tensor_map["_priority"]->GetItemAt<double>(&priority, {});
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "label shape"
<< tensor_map["label"] << "priority: " << priority << "\n";
i++;
di.GetNextAsMap(&tensor_map);
}
MS_LOG(INFO) << "got rows" << i << "\n";
EXPECT_TRUE(i == 14);
}
}
TEST_F(MindDataTestAlbum, TestSequentialAlbumWithFullSchema) {
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
// add the priority column
std::string schema_file = datasets_root_path_ + "/testAlbum/fullSchema.json";
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
MS_LOG(ERROR) << "Return code error detected during tree launch: " << ".";
EXPECT_TRUE(false);
} else {
DatasetIterator di(tree);
TensorMap tensor_map;
di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
uint64_t i = 0;
int32_t label = 0;
double priority = 0;
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
tensor_map["_priority"]->GetItemAt<double>(&priority, {});
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "label shape"
<< tensor_map["label"] << "priority: " << priority << " embedding : " <<
tensor_map["_embedding"]->shape() << "\n";
i++;
di.GetNextAsMap(&tensor_map);
}
MS_LOG(INFO) << "got rows" << i << "\n";
EXPECT_TRUE(i == 14);
}
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestPipeline, TestAlbumBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumBasic.";
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
std::vector<std::string> column_names = {"image", "label", "id"};
// Create an Album Dataset
std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 7);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestAlbumDecode) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumDecode.";
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
std::vector<std::string> column_names = {"image", "label", "id"};
// Create an Album Dataset
std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names, true);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
auto shape = image->shape();
MS_LOG(INFO) << "Tensor image shape size: " << shape.Size();
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
EXPECT_GT(shape.Size(), 1); // Verify decode=true took effect
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 7);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestAlbumNumSamplers) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumNumSamplers.";
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
std::vector<std::string> column_names = {"image", "label", "id"};
// Create an Album Dataset
std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names, true, SequentialSampler(0, 1));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 1);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestAlbumError) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumError.";
std::string folder_path = datasets_root_path_ + "/testAlbum/ima";
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
std::vector<std::string> column_names = {"image", "label", "id"};
// Create an Album Dataset
std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names, true, SequentialSampler(0, 1));
EXPECT_EQ(ds, nullptr);
}
......@@ -32,6 +32,8 @@ export GLOG_v=2
## prepare data for dataset & mindrecord
cp -fr $PROJECT_PATH/tests/ut/data ${PROJECT_PATH}/build/mindspore/tests/ut/cpp/
## prepare album dataset; it uses absolute paths, so it has to be generated
python ${PROJECT_PATH}/build/mindspore/tests/ut/cpp/data/dataset/testAlbum/gen_json.py
if [ $# -gt 0 ]; then
./ut_tests --gtest_filter=$1
......
{
"columns": {
"image": {
"type": "uint8",
"rank": 1
},
"label" : {
"type": "string",
"rank": 1
},
"id" : {
"type": "int64",
"rank": 0
}
}
}
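Each sample under images/ is a JSON object whose fields are matched against this schema (the 'id' column is derived from the file name by LoadIDTensor rather than stored in the record). AlbumOp reads these records with nlohmann::json, as the Load*Tensor signatures above indicate; a hedged sketch of that parse step, with the include path assumed:
// Sketch only: parse one Album record the way AlbumOp's loaders consume it.
#include <fstream>
#include <string>
#include <nlohmann/json.hpp>  // include path assumed

nlohmann::json ParseAlbumRecord(const std::string &json_path) {
  std::ifstream in(json_path);
  nlohmann::json record;
  in >> record;  // e.g. {"image": "original/xxx.jpg", "label": ["3", "2"], "_priority": 0.8}
  return record;
}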
......@@ -5,7 +5,7 @@
"rank": 1
},
"label" : {
"type": "int32",
"type": "string",
"rank": 1
},
"id" : {
......
......@@ -5,7 +5,7 @@
"rank": 1
},
"label" : {
"type": "int32",
"type": "string",
"rank": 1
},
"id" : {
......
......@@ -2,21 +2,21 @@ import json
import os
def dump_json_from_dict(structure, file_name):
with open(file_name + '.json', 'w') as file_path:
json.dump(structure, file_path)
with open(file_name + '.json', 'w') as fp:
json.dump(structure, fp)
if __name__ == '__main__':
# iterate over directory
DIRECTORY = "imagefolder"
i = 0
# iterate over DIRECTORY
DIRECTORY = os.path.dirname(os.path.realpath(__file__)) + "/original"
PARENT_DIR = os.path.dirname(DIRECTORY)
i = -1
for filename in os.listdir(DIRECTORY):
default_dict = {}
default_dict.update(dataset='')
default_dict.update(image=(os.path.join(DIRECTORY, filename)))
default_dict.update(label=[1, 2])
default_dict.update(image=os.path.abspath(os.path.join(DIRECTORY, filename)))
default_dict.update(label=['3', '2'])
default_dict.update(_priority=0.8)
default_dict.update(_embedding='sample.bin')
default_dict.update(_segmented_image=(os.path.join(DIRECTORY, filename)))
default_dict.update(_processed_image=(os.path.join(DIRECTORY, filename)))
default_dict.update(_embedding=os.path.abspath(os.path.join(PARENT_DIR, 'sample.bin')))
default_dict.update(_processed_image=os.path.abspath(os.path.join(DIRECTORY, filename)))
i = i + 1
dump_json_from_dict(default_dict, 'images/'+str(i))
dump_json_from_dict(default_dict, PARENT_DIR + '/images/'+str(i))
{"dataset": "", "image": "original/apple_expect_decoded.jpg", "label": ["3", "2"], "_priority": 0.8, "_embedding": "sample.bin", "_processed_image": "original/apple_expect_decoded.jpg"}
{"dataset": "", "image": "imagefolder/apple_expect_decoded.jpg", "label": [1, 2], "_priority": 0.8, "_embedding": "sample.bin", "_segmented_image": "imagefolder/apple_expect_decoded.jpg", "_processed_image": "imagefolder/apple_expect_decoded.jpg"}
\ No newline at end of file
{"dataset": "", "image": "testAlbum//testAlbum/original/apple_expect_resize_bilinear.jpg", "label": ["3", "2"], "_priority": 0.8, "_embedding": "testAlbum//testAlbum/sample.bin", "_processed_image": "testAlbum//testAlbum/original/apple_expect_resize_bilinear.jpg"}
{"dataset": "", "image": "imagefolder/apple_expect_resize_bilinear.jpg", "label": [1, 2], "_priority": 0.8, "_embedding": "sample.bin", "_segmented_image": "imagefolder/apple_expect_resize_bilinear.jpg", "_processed_image": "imagefolder/apple_expect_resize_bilinear.jpg"}
\ No newline at end of file
{"dataset": "", "image": "testAlbum//testAlbum/original/apple_expect_changemode.jpg", "label": ["3", "2"], "_priority": 0.8, "_embedding": "testAlbum//testAlbum/sample.bin", "_processed_image": "testAlbum//testAlbum/original/apple_expect_changemode.jpg"}
{"dataset": "", "image": "imagefolder/apple_expect_changemode.jpg", "label": [1, 2], "_priority": 0.8, "_embedding": "sample.bin", "_segmented_image": "imagefolder/apple_expect_changemode.jpg", "_processed_image": "imagefolder/apple_expect_changemode.jpg"}
\ No newline at end of file
{"dataset": "", "image": "testAlbum//testAlbum/original/apple_expect_not_flip.jpg", "label": ["3", "2"], "_priority": 0.8, "_embedding": "testAlbum//testAlbum/sample.bin", "_processed_image": "testAlbum//testAlbum/original/apple_expect_not_flip.jpg"}
{"dataset": "", "image": "imagefolder/apple_expect_not_flip.jpg", "label": [1, 2], "_priority": 0.8, "_embedding": "sample.bin", "_segmented_image": "imagefolder/apple_expect_not_flip.jpg", "_processed_image": "imagefolder/apple_expect_not_flip.jpg"}
\ No newline at end of file
{"dataset": "", "image": "testAlbum//testAlbum/original/apple_expect_flipped_horizontal.jpg", "label": ["3", "2"], "_priority": 0.8, "_embedding": "testAlbum//testAlbum/sample.bin", "_processed_image": "testAlbum//testAlbum/original/apple_expect_flipped_horizontal.jpg"}
{"dataset": "", "image": "imagefolder/apple_expect_flipped_horizontal.jpg", "label": [1, 2], "_priority": 0.8, "_embedding": "sample.bin", "_segmented_image": "imagefolder/apple_expect_flipped_horizontal.jpg", "_processed_image": "imagefolder/apple_expect_flipped_horizontal.jpg"}
\ No newline at end of file
{"dataset": "", "image": "testAlbum//testAlbum/original/apple_expect_rescaled.jpg", "label": ["3", "2"], "_priority": 0.8, "_embedding": "testAlbum//testAlbum/sample.bin", "_processed_image": "testAlbum//testAlbum/original/apple_expect_rescaled.jpg"}
{"dataset": "", "image": "imagefolder/apple_expect_rescaled.jpg", "label": [1, 2], "_priority": 0.8, "_embedding": "sample.bin", "_segmented_image": "imagefolder/apple_expect_rescaled.jpg", "_processed_image": "imagefolder/apple_expect_rescaled.jpg"}
\ No newline at end of file
{"dataset": "", "image": "testAlbum//testAlbum/original/apple_expect_flipped_vertical.jpg", "label": ["3", "2"], "_priority": 0.8, "_embedding": "testAlbum//testAlbum/sample.bin", "_processed_image": "testAlbum//testAlbum/original/apple_expect_flipped_vertical.jpg"}
{"dataset": "", "image": "imagefolder/apple_expect_flipped_vertical.jpg", "label": [1, 2], "_priority": 0.8, "_embedding": "sample.bin", "_segmented_image": "imagefolder/apple_expect_flipped_vertical.jpg", "_processed_image": "imagefolder/apple_expect_flipped_vertical.jpg"}
\ No newline at end of file