提交 270bf831 编写于 作者: J Jesse Lee

Random Data Op

上级 05676676
......@@ -28,6 +28,7 @@
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "mindrecord/include/shard_category.h"
......@@ -65,6 +66,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kCifar10, &DEPipeline::ParseCifar10Op},
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp},
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
......@@ -972,6 +974,45 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
return Status::OK();
}
Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
// Required arguments
RandomDataOp::Builder builder;
if (args["num_samples"].is_none()) {
std::string err_msg = "Error: num_samples is a required argument";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::vector<std::string> columns_to_load;
bool schema_exists = false;
// Optional arguments
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (key == "num_parallel_workers") {
(void)builder.SetNumWorkers(ToInt(value));
} else if (key == "schema_file_path" || key == "schema_json_string") {
schema_exists = true;
} else if (key == "num_samples") {
(void)builder.SetTotalRows(ToInt(value));
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
}
}
if (schema_exists) {
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
if (args.contains("schema_file_path")) {
RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load));
} else {
RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load));
}
(void)builder.SetDataSchema(std::move(schema));
}
std::shared_ptr<RandomDataOp> op;
RETURN_IF_NOT_OK(builder.Build(&op));
*ptr = op;
return Status::OK();
}
int32_t DEPipeline::GetNumClasses() const { return num_classes_; }
Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
......
......@@ -60,6 +60,7 @@ enum OpName {
kCifar10,
kCifar100,
kCelebA,
kRandomData,
kTextFile
};
......@@ -142,6 +143,8 @@ class DEPipeline {
Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
void PrintTree();
int32_t GetNumClasses() const;
......
......@@ -47,6 +47,7 @@
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
......@@ -489,6 +490,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("VOC", OpName::kVoc)
.value("CIFAR10", OpName::kCifar10)
.value("CIFAR100", OpName::kCifar100)
.value("RANDOMDATA", OpName::kRandomData)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile);
......
......@@ -466,5 +466,24 @@ Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) {
"\"columns\" node is required in the schema json file.");
return Status::OK();
}
// Loops through all columns in the schema and returns a map with the column
// name to column index number.
Status DataSchema::GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map) {
if (out_column_name_map == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"unexpected null output column name map.");
}
for (int32_t i = 0; i < col_descs_.size(); ++i) {
if (col_descs_[i].name().empty()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Constructing column name map from schema, but found empty column name.");
}
(*out_column_name_map)[col_descs_[i].name()] = i;
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
......@@ -20,6 +20,7 @@
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <nlohmann/json.hpp>
#include "dataset/core/constants.h"
......@@ -180,6 +181,12 @@ class DataSchema {
static const char DEFAULT_DATA_SCHEMA_FILENAME[];
// Loops through all columns in the schema and returns a map with the column
// name to column index number.
// @param out_column_name_map - The output map of columns names to column index
// @return Status - The error code return
Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map);
private:
// Internal helper function. Parses the json schema file in any order and produces a schema that
// does not follow any particular order (json standard does not enforce any ordering protocol).
......
......@@ -17,6 +17,7 @@ add_library(engine-datasetops-source OBJECT
${FEATURE_SRCS}
manifest_op.cc
cifar_op.cc
random_data_op.cc
celeba_op.cc
text_file_op.cc
)
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/random_data_op.h"
#include <iomanip>
#include <random>
#include "dataset/engine/execution_tree.h"
#include "dataset/core/config_manager.h"
#include "dataset/util/random.h"
#include "dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
RandomDataOp::Builder::Builder()
: builder_data_schema_(nullptr),
builder_num_workers_(0),
builder_op_connector_size_(0),
builder_rows_per_buffer_(0),
builder_total_rows_(0) {
// Some arguments to the RandomDataOp have a default argument that is taken from the config.
// The user may override these defaults by using the builder set methods.
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
builder_num_workers_ = cfg->num_parallel_workers();
builder_op_connector_size_ = cfg->op_connector_size();
}
// The build method that produces the instantiated RandomDataOp as a shared pointer
Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) {
RETURN_IF_NOT_OK(SanityCheck());
*out_op = std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_,
builder_total_rows_, std::move(builder_data_schema_));
// If the user did not provide a schema, then we will ask the op to generate a pseudo-random
// schema.
// See details of generateSchema function to learn what type of schema it will create.
if ((*out_op)->data_schema_ == nullptr) {
RETURN_IF_NOT_OK((*out_op)->GenerateSchema());
}
// Extract the column name mapping from the schema and save it in the class.
// This will be needed when constructing buffers.
RETURN_IF_NOT_OK((*out_op)->data_schema_->GetColumnNameMap(&((*out_op)->column_name_map_)));
return Status::OK();
}
// Check if the required parameters are set by the builder.
Status RandomDataOp::Builder::SanityCheck() const {
// There actually is no required arguments for the random data op at all.
// Some arguments are preset with global values from config, and if they are not given by the user
// then we create them randomly. Leaving this function here for consistency with other operators.
return Status::OK();
}
// Constructor for RandomDataOp
RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
std::unique_ptr<DataSchema> data_schema)
: ParallelOp(num_workers, op_connector_size),
buffer_id_(0),
rows_per_buffer_(rows_per_buffer),
total_rows_(total_rows),
epoch_buffers_sent_(0),
guys_in_(0),
guys_out_(num_workers_),
eoe_worker_id_(0),
data_schema_(std::move(data_schema)) {
rand_gen_.seed(GetSeed()); // seed the random generator
// If total rows was not given, then randomly pick a number
if (total_rows_ == 0) {
total_rows_ = GenRandomInt(1, kMaxTotalRows);
}
// Everyone is already out from the sync area.
all_out_.Set();
}
// A print method typically used for debugging
void RandomDataOp::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <RandomDataOp>:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << " [total rows: " << total_rows_ << "]\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nTotal_rows: " << total_rows_
<< "\nRows per buffer: " << rows_per_buffer_
<< "\nSchema:\n" << *data_schema_ << "\n\n";
}
}
// Helper function to produce a default/random schema if one didn't exist
Status RandomDataOp::GenerateSchema() {
if (data_schema_ != nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Generating a schema but one already exists!");
}
// To randomly create a schema, we need to choose:
// a) how many columns
// b) the type of each column
// c) the shape of each column (number of dimensions i.e. rank)
// d) the shape of each column (dimension values)
data_schema_ = std::make_unique<DataSchema>();
std::unique_ptr<TensorShape> newShape;
std::unique_ptr<ColDescriptor> newCol;
// Loop over the number of chosen columns
int32_t numColumns = GenRandomInt(1, kMaxNumColumns);
for (int32_t i = 0; i < numColumns; i++) {
// For each column:
// - choose a datatype
// - generate a shape that randomly chooses the number of dimensions and the dimension values.
DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(0, kMaxDataType));
int32_t rank = GenRandomInt(1, kMaxRank);
std::vector<dsize_t> dims;
for (int32_t d = 0; d < rank; d++) {
// 0 is not a valid dimension value. however, we can support "*" or unknown, so map the random
// 0 value to the unknown attribute if 0 is chosen
dsize_t dim_value = static_cast<dsize_t>(GenRandomInt(0, kMaxDimValue));
if (dim_value == 0) dim_value = TensorShape::kDimUnknown;
dims.push_back(dim_value);
}
newShape = std::make_unique<TensorShape>(dims);
// Create the column descriptor
std::string colName = "c" + std::to_string(i);
newCol = std::make_unique<ColDescriptor>(colName, DataType(newType), TensorImpl::kFlexible, rank,
newShape.get());
data_schema_->AddColumn(*newCol);
}
return Status::OK();
}
// Class functor operator () override.
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work.
Status RandomDataOp::operator()() {
// First, compute how many buffers we'll need to satisfy the total row count.
// The only reason we do this is for the purpose of throttling worker count if needed.
int64_t buffers_needed = total_rows_ / rows_per_buffer_;
if (total_rows_ % rows_per_buffer_ != 0) {
buffers_needed++;
}
// If the amount of workers we have exceeds the number of buffers to produce, then we'll have
// idle workers doing nothing. In that case, let's throttle the worker count.
if (num_workers_ > buffers_needed) {
MS_LOG(INFO) << "RandomDataOp throttling worker count from " << num_workers_ << "to " << buffers_needed;
num_workers_ = buffers_needed;
num_producers_ = num_workers_;
guys_out_ = num_workers_;
// The output connector was already created with a different worker count. We have to drop and recreate
// that connector.
DatasetOp::CreateConnector(num_producers_, num_workers_);
}
// Assign the number of rows to each worker in a round robin fashion.
worker_max_rows_.reserve(num_workers_);
worker_rows_packed_.reserve(num_workers_);
// init the counts to zero to start.
for (int32_t w = 0; w < num_workers_; w++) {
worker_max_rows_.push_back(0);
worker_rows_packed_.push_back(0);
}
// then assign round robin row counts
int32_t currentWorker = 0;
for (int64_t r = 0; r < total_rows_; r++) {
worker_max_rows_[currentWorker]++;
currentWorker = (currentWorker + 1) % num_workers_;
}
// Next, compute the total buffer count. This stat is needed during reset logic
for (int32_t w = 0; w < num_workers_; w++) {
int64_t worker_buffers = 0;
worker_buffers = worker_max_rows_[w] / rows_per_buffer_;
if (worker_max_rows_[w] % rows_per_buffer_ != 0) worker_buffers++;
epoch_buffers_sent_ += worker_buffers;
}
// For the connector to work, we need to target the correct worker channel for the eoe.
// This will initialize it for the first one. reset() handles for the rest of the epochs.
eoe_worker_id_ = epoch_buffers_sent_ % num_workers_;
epoch_buffers_sent_++; // Add the eoe buffer to the count for subsequent epochs
// RandomDataOp doesn't need the master thread to stay around. Kick off the workers and then master exits.
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&RandomDataOp::WorkerEntry, this, std::placeholders::_1)));
// required task group setup after launching workers
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(epoch_sync_wait_post_.Register(tree_->AllTasks()));
return Status::OK();
}
// Performs a synchronization between workers at the end of an epoch
Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) {
MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " syncing at end of epoch";
// Sync on the guys_in counter
// We have to wait the last guy is out.
all_out_.Wait();
// If we are not in a repeat loop, or that was the last repeat already, then setup our exit
// condition from the master loop.
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
*quitting = true;
}
auto prev = guys_in_.fetch_add(1);
bool last_guy_in = (prev + 1) == num_workers_;
// If we are the last worker to hit this sync point, we have some extra tasks
if (last_guy_in) {
MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker "
<< eoe_worker_id_;
// Prepare for sync
all_out_.Clear();
// Always flow eoe at the end
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(eoe_worker_id_, std::move(eoe_buffer)));
// If we're done then also flow the eof
if (*quitting) {
// The eof needs to be sent from the next sender in the round robin, so +1
int32_t eof_worker_id = (eoe_worker_id_ + 1) % num_workers_;
MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " has no more epochs. sending eof as worker "
<< eof_worker_id;
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(eof_worker_id, std::move(eof_buffer)));
}
}
// Wait for the reset to wake us up if we're not quitting
if (!(*quitting)) {
MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entering sync wait.";
RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait());
prev = guys_out_.fetch_add(1);
bool last_guy_out = (prev + 1) == num_workers_;
// Last guy out will clear the wait post and set the row counts
if (last_guy_out) {
MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " last guy out clearing wait post.";
epoch_sync_wait_post_.Clear();
guys_in_ = 0;
all_out_.Set();
}
}
MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " epoch sync complete.";
return Status::OK();
}
// The entry point code for when workers are launched
Status RandomDataOp::WorkerEntry(int32_t worker_id) {
MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entry";
// handshake with the master first to tell it we're alive
TaskManager::FindMe()->Post();
bool quitting = false;
std::unique_ptr<TensorQTable> new_tensor_table = nullptr;
// Loop until the quitting variable gets set to true
do {
// If we have not yet reached the row count for this worker then produce another record
if (worker_rows_packed_[worker_id] < worker_max_rows_[worker_id]) {
TensorRow new_row;
// Start a new tensor table if needed
if (new_tensor_table == nullptr) {
new_tensor_table = std::make_unique<TensorQTable>();
}
// Create the data for the row
RETURN_IF_NOT_OK(CreateRandomRow(worker_id, &new_row));
// Add the row to our table
new_tensor_table->push_back(std::move(new_row));
worker_rows_packed_[worker_id]++;
// If the tensor table is at capacity then it's time to send it to output
if (new_tensor_table->size() == rows_per_buffer_) {
RETURN_IF_NOT_OK(PackAndSend(worker_id, std::move(new_tensor_table)));
}
} else {
// We've reached the total row count for this worker, so it's time for epoch sync.
// There is likely some records built but not sent yet, so take care of those first
// (this buffer will be smaller than rows_per_buffer)
if (new_tensor_table != nullptr && new_tensor_table->size() > 0) {
RETURN_IF_NOT_OK(PackAndSend(worker_id, std::move(new_tensor_table)));
}
// Now, let's enter the epoch sync
RETURN_IF_NOT_OK(EpochSync(worker_id, &quitting));
}
} while (!quitting);
MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is now quitting.";
return Status::OK();
}
// A helper function to stuff the tensor table into a buffer and send it to output connector
Status RandomDataOp::PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table) {
auto new_buffer = std::make_unique<DataBuffer>(GetNextBufferId(), DataBuffer::kDeBFlagNone);
new_buffer->set_tensor_table(std::move(in_table));
new_buffer->set_column_name_map(column_name_map_);
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(new_buffer)));
return Status::OK();
}
// A helper function to create random data for the row
Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) {
if (new_row == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Missing tensor row output");
}
// Create a tensor for each column, then add the tensor to the row
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
const ColDescriptor current_col = data_schema_->column(i);
std::vector<dsize_t> current_shape = current_col.shape().AsVector();
std::unique_ptr<TensorShape> new_shape = nullptr;
std::unique_ptr<unsigned char[]> buf = nullptr;
std::shared_ptr<Tensor> new_tensor = nullptr;
// We need to resolve the shape to fill in any unknown dimensions with random
// values, then use that as our shape for this tensor.
for (int j = 0; j < current_shape.size(); ++j) {
if (current_shape[j] == TensorShape::kDimUnknown) {
current_shape[j] = static_cast<dsize_t>(GenRandomInt(1, kMaxDimValue));
}
}
new_shape = std::make_unique<TensorShape>(current_shape);
int64_t size_in_bytes = new_shape->NumOfElements() * current_col.type().SizeInBytes();
// Generate a random byte of data. This may cause some funny data for things like doubles,floats, bools
// however the random data op is not too concerned about the physical data itself.
std::uniform_int_distribution<uint8_t> uniDist(0, 255);
uint8_t random_byte = uniDist(rand_gen_);
// Now, create a chunk of memory for the entire tensor and copy this byte in repeatedly.
buf = std::make_unique<unsigned char[]>(size_in_bytes);
int ret_code = memset_s(buf.get(), size_in_bytes, random_byte, size_in_bytes);
if (ret_code != 0) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Failed to set random bytes for a tensor.");
}
RETURN_IF_NOT_OK(
Tensor::CreateTensor(&new_tensor, current_col.tensorImpl(), *new_shape, current_col.type(), buf.get()));
// Add this tensor to the tensor row for output
(*new_row).push_back(std::move(new_tensor));
}
return Status::OK();
}
// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from it's previous execution and then initializes itself so that it can be executed
// again.
Status RandomDataOp::Reset() {
MS_LOG(INFO) << "RandomDataOp resetting.";
// Ensure all guys are in the waitpost
if (guys_in_ != num_workers_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Issuing a reset, but some workers are missing from epochSync!");
}
// reset the row counters for all workers
for (int32_t w = 0; w < num_workers_; w++) {
worker_rows_packed_[w] = 0;
worker_max_rows_[w] = 0;
}
buffer_id_ = 0;
// Re-assign round robin row counts, starting from the worker after the one that gave
// the eoe last time
int32_t currentWorker = (eoe_worker_id_ + 1) % num_workers_;
for (int64_t r = 0; r < total_rows_; r++) {
worker_max_rows_[currentWorker]++;
currentWorker = (currentWorker + 1) % num_workers_;
}
// Compute which worker should get the eoe for the next epoch
eoe_worker_id_ = ((epoch_buffers_sent_ % num_workers_) + eoe_worker_id_) % num_workers_;
// Wake up the workers to get them going again in a new epoch
guys_out_ = 0;
epoch_sync_wait_post_.Set();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_
#define DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_
#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <vector>
#include <unordered_map>
#include <utility>
#include "dataset/util/status.h"
#include "dataset/core/tensor.h"
#include "dataset/core/data_type.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
// The RandomDataOp is a leaf node storage operator that generates random data based
// on the schema specifications. Typically, it's used for testing and demonstrating
// various dataset operator pipelines. It is not "real" data to train with.
// The data that is random created is just random and repeated bytes, there is no
// "meaning" behind what these bytes are.
class RandomDataOp : public ParallelOp {
public:
// Some constants to provide limits to random generation.
static constexpr int32_t kMaxNumColumns = 4;
static constexpr int32_t kMaxRank = 4;
static constexpr int32_t kMaxDimValue = 2048;
static constexpr int32_t kMaxDataType = (DataType::DE_UNKNOWN - 1);
static constexpr int32_t kMaxTotalRows = 1024;
// A nested builder class to aid in the construction of a RandomDataOp
class Builder {
public:
/**
* Builder constructor. Creates the builder object.
* @note No default args.
* @return This is a constructor.
*/
Builder();
/**
* Default destructor
*/
~Builder() = default;
/**
* The build method that produces the instantiated RandomDataOp as a shared pointer
* @param out_op - The output RandomDataOperator that was constructed
* @return Status - The error code return
*/
Status Build(std::shared_ptr<RandomDataOp> *out_op);
/**
* Builder set method
* @param data_schema - A user-provided schema
* @return Builder - The modified builder by reference
*/
Builder &SetDataSchema(std::unique_ptr<DataSchema> data_schema) {
builder_data_schema_ = std::move(data_schema);
return *this;
}
/**
* Builder set method
* @param num_workers - The number of workers
* @return Builder - The modified builder by reference
*/
Builder &SetNumWorkers(int32_t num_workers) {
builder_num_workers_ = num_workers;
return *this;
}
/**
* Builder set method
* @param op_connector_size - The size of the output connector
* @return Builder - The modified builder by reference
*/
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}
/**
* Builder set method
* @param rows_per_buffer - The number of rows in each DataBuffer
* @return Builder - The modified builder by reference
*/
Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
builder_rows_per_buffer_ = rows_per_buffer;
return *this;
}
/**
* Builder set method
* @param total_rows - The total number of rows in the dataset
* @return Builder - The modified builder by reference
*/
Builder &SetTotalRows(int64_t total_rows) {
builder_total_rows_ = total_rows;
return *this;
}
private:
/**
* Check if the required parameters are set by the builder.
* @return Status - The error code return
*/
Status SanityCheck() const;
std::unique_ptr<DataSchema> builder_data_schema_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int64_t builder_rows_per_buffer_;
int64_t builder_total_rows_;
}; // class Builder
/**
* Constructor for RandomDataOp
* @note Private constructor. Must use builder to construct.
* @param num_workers - The number of workers
* @param op_connector_size - The size of the output connector
* @param rows_per_buffer - The number of rows in each DataBuffer
* @param data_schema - A user-provided schema
* @param total_rows - The total number of rows in the dataset
* @return Builder - The modified builder by reference
*/
RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
std::unique_ptr<DataSchema> data_schema);
/**
* Destructor
*/
~RandomDataOp() = default;
/**
* A print method typically used for debugging
* @param out - The output stream to write output to
* @param show_all - A bool to control if you want to show all info or just a summary
*/
void Print(std::ostream &out, bool show_all) const override;
/**
* << Stream output operator overload
* @notes This allows you to write the debug print info using stream operators
* @param out - reference to the output stream being overloaded
* @param so - reference to the ShuffleOp to display
* @return - the output stream must be returned
*/
friend std::ostream &operator<<(std::ostream &out, const RandomDataOp &op) {
op.Print(out, false);
return out;
}
/**
* Class functor operator () override.
* All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
* provide the master loop that drives the logic for performing the work.
* @return Status - The error code return
*/
Status operator()() override;
/**
* Overrides base class reset method. When an operator does a reset, it cleans up any state
* info from it's previous execution and then initializes itself so that it can be executed
* again.
* @return Status - The error code return
*/
Status Reset() override;
/**
* Quick getter for total rows.
*/
int64_t GetTotalRows() const { return total_rows_; }
private:
/**
* The entry point code for when workers are launched
* @param worker_id - The worker id
* @return Status - The error code return
*/
Status WorkerEntry(int32_t worker_id) override;
/**
* Helper function to produce a default/random schema if one didn't exist
@return Status - The error code return
*/
Status GenerateSchema();
/**
* Performs a synchronization between workers at the end of an epoch
* @param worker_id - The worker id
* @return Status - The error code return
*/
Status EpochSync(int32_t worker_id, bool *quitting);
/**
* A helper function to stuff the tensor table into a buffer and send it to output connector
* @param worker_id - The worker id
* @param in_table - The tensor table to pack and send
* @return Status - The error code return
*/
Status PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table);
/**
* A helper function to create random data for the row
* @param worker_id - The worker id
* @param new_row - The output row to produce
* @return Status - The error code return
*/
Status CreateRandomRow(int32_t worker_id, TensorRow *new_row);
/**
* A quick inline for producing a random number between (and including) min/max
* @param min - minimum number that can be generated
* @param max - maximum number that can be generated
* @return - The generated random number
*/
inline int32_t GenRandomInt(int32_t min, int32_t max) {
std::uniform_int_distribution<int32_t> uniDist(min, max);
return uniDist(rand_gen_);
}
/**
* A quick inline for producing the next buffer id in sequence, threadsafe
* @return - The next buffer id.
*/
inline int32_t GetNextBufferId() {
std::unique_lock<std::mutex> lock(buffer_id_mutex_);
return ++buffer_id_;
}
int32_t buffer_id_;
int64_t rows_per_buffer_;
int64_t total_rows_;
int64_t epoch_buffers_sent_;
std::atomic<int32_t> guys_in_;
std::atomic<int32_t> guys_out_;
int32_t eoe_worker_id_;
std::unique_ptr<DataSchema> data_schema_;
std::vector<int64_t> worker_max_rows_;
std::vector<int64_t> worker_rows_packed_;
std::unordered_map<std::string, int32_t> column_name_map_;
std::mt19937 rand_gen_;
WaitPost epoch_sync_wait_post_;
WaitPost all_out_;
std::mutex buffer_id_mutex_;
}; // class RandomDataOp
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_
......@@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
from .core.configuration import config
from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
Schema, Shuffle, zip
Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show
......
......@@ -3146,6 +3146,57 @@ class Cifar100Dataset(SourceDataset):
return get_num_rows(num_rows, self.num_shards)
class RandomDataset(SourceDataset):
"""
A source dataset that generates random data.
Args:
num_samples (int): number of samples to generate.
schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
If the schema is not provided, the meta data from the TFRecord file is considered the schema.
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
"""
def __init__(self, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None):
super().__init__(num_parallel_workers)
schema_obj = None
if (schema is not None) and (not isinstance(schema, Schema)):
schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
self.schema = schema
self.columns_list = columns_list
self.num_samples = num_samples
if schema_obj is not None and num_samples is None:
self.num_samples = schema_obj.num_rows
def get_args(self):
args = super().get_args()
if self.schema is not None:
if isinstance(self.schema, Schema):
self.schema.datasetType = 'Random'
if self.num_samples is not None:
self.schema.num_rows = self.num_samples
args["schema_json_string"] = self.schema.to_json()
else:
args["schema_file_path"] = self.schema
args["schema"] = self.schema
if self.columns_list is not None:
args["columns_list"] = self.columns_list
if self.num_samples is not None:
args["num_samples"] = self.num_samples
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
return num_samples
class Schema:
"""
Class to represent a schema of dataset.
......
......@@ -192,6 +192,8 @@ class Iterator:
op_type = OpName.CIFAR100
elif isinstance(dataset, de.CelebADataset):
op_type = OpName.CELEBA
elif isinstance(dataset, de.RandomDataset):
op_type = OpName.RANDOMDATA
elif isinstance(dataset, de.TextFileDataset):
op_type = OpName.TEXTFILE
else:
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include <memory>
#include <vector>
#include <iostream>
#include "dataset/core/tensor_shape.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/data_schema.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestRandomDataOp : public UT::DatasetOpTesting {
};
// Test info:
// - Simple test with a user-provided schema generated purely from DataSchema C API
// - has an interation loop
//
// Tree: single node tree with RandomDataOp
//
// RandomDataOp
//
TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic1) {
Status rc;
int32_t rank = 0; // not used
MS_LOG(INFO) << "UT test RandomDataOpBasic1";
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
// Create a schema using the C api's
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
// RandomDataOp can randomly fill in unknown dimension lengths of a shape.
// Most other ops cannot do that as they are limited by the physical data itself. We're
// more flexible with random data since it is just making stuff up on the fly.
TensorShape c1Shape({TensorShape::kDimUnknown, TensorShape::kDimUnknown, 3});
ColDescriptor c1("image",
DataType(DataType::DE_INT8),
TensorImpl::kFlexible,
rank, // not used
&c1Shape);
// Column 2 will just be a scalar label number
TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor c2("label",
DataType(DataType::DE_UINT32),
TensorImpl::kFlexible,
rank,
&c2Shape);
testSchema->AddColumn(c1);
testSchema->AddColumn(c2);
std::shared_ptr<RandomDataOp> myRandomDataOp;
RandomDataOp::Builder builder;
rc = builder.SetRowsPerBuffer(2)
.SetNumWorkers(1)
.SetDataSchema(std::move(testSchema))
.SetTotalRows(25)
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
std::ostringstream ss;
ss << *myRandomDataOp;
MS_LOG(INFO) << "RandomDataOp print: %s" << ss.str();
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
// Don't display these rows...too big to show
MS_LOG(INFO) << "Row fetched #: " << rowCount;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 25);
}
// Test info:
// - Simple test with a randomly generated schema
// - no iteration loop on this one, just create the op
//
// Tree: single node tree with RandomDataOp
//
// RandomDataOp
//
TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic2) {
Status rc;
MS_LOG(INFO) << "UT test RandomDataOpBasic2";
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
std::shared_ptr<RandomDataOp> myRandomDataOp;
RandomDataOp::Builder builder;
rc = builder.SetRowsPerBuffer(2)
.SetNumWorkers(1)
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
std::ostringstream ss;
ss << *myRandomDataOp;
MS_LOG(INFO) << "RandomDataOp print: " << ss.str();
}
// Test info:
// - json file test with iteration
//
// Tree: single node tree with RandomDataOp
//
// RandomDataOp
//
TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic3) {
Status rc;
MS_LOG(INFO) << "UT test RandomDataOpBasic3";
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
rc = testSchema->LoadSchemaFile(datasets_root_path_ + "/testRandomData/datasetSchema.json", {});
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<RandomDataOp> myRandomDataOp;
RandomDataOp::Builder builder;
rc = builder.SetRowsPerBuffer(2)
.SetNumWorkers(1)
.SetDataSchema(std::move(testSchema))
.SetTotalRows(10)
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
std::ostringstream ss;
ss << *myRandomDataOp;
MS_LOG(INFO) << "RandomDataOp print: " << ss.str();
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
// Don't display these rows...too big to show
MS_LOG(INFO) << "Row fetched #: " << rowCount;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 10);
}
// Test info:
// - json schema input it's a fairly simple one
// - has an interation loop
//
// Tree: RepeatOp over RandomDataOp
//
// RepeatOp
// |
// RandomDataOp
//
TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic4) {
Status rc;
MS_LOG(INFO) << "UT test RandomDataOpBasic4";
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
rc = testSchema->LoadSchemaFile(datasets_root_path_ + "/testRandomData/datasetSchema2.json", {});
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<RandomDataOp> myRandomDataOp;
RandomDataOp::Builder builder;
rc = builder.SetRowsPerBuffer(2)
.SetNumWorkers(1)
.SetDataSchema(std::move(testSchema))
.SetTotalRows(10)
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
uint32_t numRepeats = 2;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats)
.Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myRepeatOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
MS_LOG(INFO) << "Row display for row #: " << rowCount;
// Display the tensor by calling the printer on it
for (int i = 0; i < tensorList.size(); i++) {
std::ostringstream ss;
ss << *tensorList[i] << std::endl;
MS_LOG(INFO) << "Tensor print: %s" << ss.str();
}
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 20);
}
// Test info:
// - json schema input it's a fairly simple one
// - has an interation loop
// - same as MindDataTestRandomDataOpBasic4 except that this one will have parallel workers
//
// Tree: RepeatOp over RandomDataOp
//
// RepeatOp
// |
// RandomDataOp
//
TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic5) {
Status rc;
MS_LOG(INFO) << "UT test RandomDataOpBasic5";
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
rc = testSchema->LoadSchemaFile(datasets_root_path_ + "/testRandomData/datasetSchema2.json", {});
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<RandomDataOp> myRandomDataOp;
RandomDataOp::Builder builder;
rc = builder.SetRowsPerBuffer(2)
.SetNumWorkers(4)
.SetDataSchema(std::move(testSchema))
.SetTotalRows(10)
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
uint32_t numRepeats = 3;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats)
.Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myRepeatOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
MS_LOG(INFO) << "Row display for row #: " << rowCount;
// Display the tensor by calling the printer on it
for (int i = 0; i < tensorList.size(); i++) {
std::ostringstream ss;
ss << *tensorList[i] << std::endl;
MS_LOG(INFO) << "Tensor print: ", ss.str();
}
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 30);
}
// Test info:
// - repeat shuffle random
//
// Tree: RepeatOp over RandomDataOp
//
// RepeatOp
// |
// ShuffleOp
// |
// RandomDataOp
//
TEST_F(MindDataTestRandomDataOp, RandomDataOpTree1) {
Status rc;
MS_LOG(INFO) << "UT test RandomDataOpTree1";
// Start with an empty execution tree
auto myTree = std::make_shared<ExecutionTree>();
std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>();
rc = testSchema->LoadSchemaFile(datasets_root_path_ + "/testRandomData/datasetSchema2.json", {});
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<RandomDataOp> myRandomDataOp;
RandomDataOp::Builder builder;
rc = builder.SetRowsPerBuffer(2)
.SetNumWorkers(4)
.SetDataSchema(std::move(testSchema))
.SetTotalRows(10)
.Build(&myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<ShuffleOp> myShuffleOp;
rc = ShuffleOp::Builder()
.SetRowsPerBuffer(2)
.SetShuffleSize(4)
.Build(&myShuffleOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myShuffleOp);
EXPECT_TRUE(rc.IsOk());
uint32_t numRepeats = 3;
std::shared_ptr<RepeatOp> myRepeatOp;
rc = RepeatOp::Builder(numRepeats)
.Build(&myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myRepeatOp->AddChild(myShuffleOp);
EXPECT_TRUE(rc.IsOk());
rc = myShuffleOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = myTree->Launch();
EXPECT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator dI(myTree);
TensorRow tensorList;
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
int rowCount = 0;
while (!tensorList.empty()) {
MS_LOG(INFO) << "Row display for row #: " << rowCount;
// Display the tensor by calling the printer on it
for (int i = 0; i < tensorList.size(); i++) {
std::ostringstream ss;
ss << *tensorList[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << ss.str();
}
rc = dI.FetchNextTensorRow(&tensorList);
EXPECT_TRUE(rc.IsOk());
rowCount++;
}
ASSERT_EQ(rowCount, 30);
}
{
"columns": {
"image": {
"type": "uint8",
"rank": 3,
"shape": [1920,1080,3]
},
"label": {
"type": "int32",
"rank": 1,
"shape": [1]
}
}
}
{
"columns": {
"image": {
"type": "uint8",
"rank": 2,
"shape": [28,28]
},
"label": {
"type": "uint8",
"rank": 1,
"shape": [1]
}
}
}
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from pathlib import Path
# just a basic test with parallel random data op
def test_randomdataset_basic1():
print("Test randomdataset basic")
schema = ds.Schema()
schema.add_column('image', de_type=mstype.uint8, shape=[2])
schema.add_column('label', de_type=mstype.uint8, shape=[1])
# apply dataset operations
ds1 = ds.RandomDataset(schema=schema, num_samples=50, num_parallel_workers=4)
ds1 = ds1.repeat(4)
num_iter = 0
for data in ds1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
print("{} image: {}".format(num_iter, data["image"]))
print("{} label: {}".format(num_iter, data["label"]))
num_iter += 1
print("Number of data in ds1: ", num_iter)
assert(num_iter == 200)
# Another simple test
def test_randomdataset_basic2():
print("Test randomdataset basic 2")
schema = ds.Schema()
schema.add_column('image', de_type=mstype.uint8, shape=[640,480,3]) # 921600 bytes (a bit less than 1 MB per image)
schema.add_column('label', de_type=mstype.uint8, shape=[1])
# Make up about 10 samples
ds1 = ds.RandomDataset(schema=schema, num_samples=10, num_parallel_workers=1)
# cache size allows for about 4 images since each image just a bit less than 1MB, after that we will have to spill
ds1 = ds1.repeat(4)
num_iter = 0
for data in ds1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
#print(data["image"])
print("printing the label: {}".format(data["label"]))
num_iter += 1
print("Number of data in ds1: ", num_iter)
assert(num_iter == 40)
if __name__ == '__main__':
test_randomdataset_basic1()
test_randomdataset_basic2()
print('test_randomdataset_basic Ended.\n')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册