提交 cf82aa90 编写于 作者: M ms_yan

init remove storage op in c++

init remove storage op test case c++

remove source c++ files
上级 2d50a43b
......@@ -48,7 +48,6 @@ namespace dataset {
using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *);
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kStorage, &DEPipeline::ParseStorageOp},
{kShuffle, &DEPipeline::ParseShuffleOp},
{kMindrecord, &DEPipeline::ParseMindRecordOp},
{kMap, &DEPipeline::ParseMapOp},
......@@ -301,70 +300,6 @@ Status DEPipeline::SetBatchParameters(const py::dict &args) {
return Status::OK();
}
Status DEPipeline::ValidateArgStorageOp(const py::dict &args) {
// Required arguments
if (((args.contains("dataset_files") && args["dataset_files"].is_none()) || args["schema"].is_none()) &&
((args.contains("dataset_dir") && args["dataset_dir"].is_none()) ||
(args["schema"].is_none() && args["schema_json_string"].is_none()))) {
std::string err_msg = "Error: at least one of dataset_files or schema_file is missing";
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
Status DEPipeline::ParseStorageOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
RETURN_IF_NOT_OK(ValidateArgStorageOp(args));
std::shared_ptr<StorageOp::Builder> builder;
if (args.contains("dataset_files") && !args["dataset_files"].is_none()) {
builder = std::make_shared<StorageOp::Builder>();
(void)builder->SetDatasetFileList(ToStringVector(args["dataset_files"]));
(void)builder->SetSchemaFile(ToString(args["schema"]));
} else if (args.contains("dataset_dir") && !args["dataset_dir"].is_none()) {
builder = std::make_shared<StorageOp::Builder>();
(void)builder->SetDatasetFilesDir(ToString(args["dataset_dir"]));
if (!args["schema"].is_none()) {
(void)builder->SetSchemaFile(ToString(args["schema"]));
} else if (!args["schema_json_string"].is_none()) {
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
std::string s = ToString(args["schema_json_string"]);
RETURN_IF_NOT_OK(schema->LoadSchemaString(s, std::vector<std::string>()));
(void)builder->SetNumRows(schema->num_rows());
(void)builder->SetSchema(std::move(schema));
}
}
// Optional arguments
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "prefetch_size") {
(void)builder->SetOpConnectorSize(ToInt(value));
} else if (key == "columns_list") {
(void)builder->SetColumnsToLoad(ToStringVector(value));
} else if (key == "distribution") {
(void)builder->SetDataDistributionFile(ToString(value));
} else if (key == "labels_filename") {
(void)builder->setLabelsFileName(ToString(value));
} else if (key == "dataset_usage") {
(void)builder->SetDatasetUsage(ToString(value));
}
}
}
(void)builder->SetBatchSize(temp_batch_size_);
(void)builder->SetDropRemainder(temp_drop_remainder_);
std::shared_ptr<StorageOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
num_rows_ = op->num_rows();
num_classes_ = op->num_classes();
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
if (!args["buffer_size"].is_none()) {
......
......@@ -37,7 +37,6 @@ using DsOpPtr = std::shared_ptr<DatasetOp>;
// enum for the dataset operator names
enum OpName {
kStorage = 0,
kShuffle,
kMindrecord,
kBatch,
......@@ -105,8 +104,6 @@ class DEPipeline {
int GetRepeatCount() const;
Status ParseStorageOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
......@@ -181,9 +178,6 @@ class DEPipeline {
std::unique_ptr<DatasetIterator> iterator_;
// Validate required args passed to storage op.
Status ValidateArgStorageOp(const py::dict &args);
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
int batch_size_;
......
......@@ -826,7 +826,6 @@ PYBIND11_MODULE(_c_dataengine, m) {
(void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(m, "DatasetOp");
(void)py::enum_<OpName>(m, "OpName", py::arithmetic())
.value("STORAGE", OpName::kStorage)
.value("SHUFFLE", OpName::kShuffle)
.value("BATCH", OpName::kBatch)
.value("BUCKETBATCH", OpName::kBucketBatch)
......
......@@ -39,7 +39,6 @@
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/datasetops/zip_op.h"
......
......@@ -17,8 +17,6 @@
#include "dataset/util/allocator.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/datasetops/source/storage_client.h"
#include "dataset/engine/datasetops/source/tf_buffer.h"
namespace mindspore {
namespace dataset {
......@@ -26,37 +24,6 @@ namespace dataset {
// Description: This is the main constructor that is used for making a buffer
DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
// Name: CreateDataBuffer()
// Description: A static factory method to create the appropriate type of derived class
// buffer. Returns the base class reference for DataBuffer.
Status DataBuffer::CreateDataBuffer(
int32_t id, // In: The id for the new buffer
std::shared_ptr<StorageClient> storage_client, // In: The storage client that is related to this buffer type
std::unique_ptr<DataBuffer> *ptr) {
std::unique_ptr<DataBuffer> new_data_buffer;
try {
DatasetType ds_type = storage_client->schema()->dataset_type();
switch (ds_type) {
case DatasetType::kTf: {
// This type of buffer is for TF record data.
// Allocate derived class version for a TF buffers
new_data_buffer = std::make_unique<TFBuffer>(id, kDeBFlagNone, storage_client);
break;
}
default: {
std::string errMsg("Invalid buffer type");
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
} catch (std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, e.what());
} catch (std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
*ptr = std::move(new_data_buffer);
return Status::OK();
}
// Name: print()
// Description: A function that prints info about the DataBuffer (base class version)
void DataBuffer::Print(std::ostream &out, // In: The output stream to print to
......
......@@ -29,9 +29,6 @@
namespace mindspore {
namespace dataset {
// Forward declares
class StorageClient;
// The DataBuffer class is a base class that will represent the data for n values based
// on a unique row id for each row of data.
// There can be different types of DataBuffers to abstract over how the data is stored
......@@ -53,14 +50,6 @@ class DataBuffer {
// Destructor
virtual ~DataBuffer();
// Name: CreateDataBuffer()
// Description: A factory method to create the appropriate type of derived class
// buffer. Returns the base class reference for DataBuffer.
static Status CreateDataBuffer(
int32_t id, // In: The id for the new buffer
std::shared_ptr<StorageClient>, // In: The StorageClient is used to choose the buffer type to create
std::unique_ptr<DataBuffer> *);
// Name: print()
// Description: A function that prints info about the DataBuffer (base class version)
virtual void Print(std::ostream &out, // In: The output stream to print to
......
......@@ -53,7 +53,7 @@ class IteratorBase {
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
// @return Status - The error code return
// @note The position of a Tensor/column might be different from the initial column order
// in the storageOp. User must be aware that MapOp, ZipOps, and others might change
// in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change
// the column ordering.
virtual Status FetchNextTensorRow(TensorRow *out_row);
......
......@@ -40,7 +40,7 @@ class ConcatOp : public PipelineOp {
~Builder() = default;
// The builder "build" method creates the final object.
// @return shared_ptr to the new StorageOp object
// @return shared_ptr to the new ConcatOp object
Status Build(std::shared_ptr<ConcatOp> *);
private:
......
......@@ -40,7 +40,7 @@ class ProjectOp : public PipelineOp {
~Builder() = default;
// The builder "build" method creates the final object.
// @return shared_ptr to the new StorageOp object.
// @return shared_ptr to the new ProjectOp object.
Status Build(std::shared_ptr<ProjectOp> *);
private:
......
......@@ -67,7 +67,7 @@ class RenameOp : public PipelineOp {
}
// The builder "build" method creates the ZipOp dataset Operator.
// @return shared_ptr to the new StorageOp object
// @return shared_ptr to the new RenameOp object
Status Build(std::shared_ptr<RenameOp> *);
private:
......
......@@ -42,7 +42,7 @@ class RepeatOp : public PipelineOp {
~Builder() = default;
// The builder "build" method creates the final object.
// @return shared_ptr to the new StorageOp object
// @return shared_ptr to the new RepeatOp object
Status Build(std::shared_ptr<RepeatOp> *);
private:
......
......@@ -101,7 +101,7 @@ class ShuffleOp : public PipelineOp {
}
// The builder "build" method creates the final object.
// @return shared_ptr to the new StorageOp object
// @return shared_ptr to the new ShuffleOp object
Status Build(std::shared_ptr<ShuffleOp> *);
private:
......
......@@ -37,7 +37,7 @@ class SkipOp : public PipelineOp {
~Builder() = default;
// The builder "build" method creates the final object.
// @return shared_ptr to the new StorageOp object
// @return shared_ptr to the new SkipOp object
Status Build(std::shared_ptr<SkipOp> *);
private:
......
......@@ -5,10 +5,6 @@ add_library(engine-datasetops-source OBJECT
generator_op.cc
io_block.cc
mindrecord_op.cc
storage_client.cc
storage_op.cc
tf_buffer.cc
tf_client.cc
tf_reader_op.cc
image_folder_op.cc
mnist_op.cc
......
......@@ -25,7 +25,7 @@
namespace mindspore {
namespace dataset {
GeneratorOp::Builder::Builder() {
// Some arguments to the StorageOp constructor have a default argument that is taken
// Some arguments to the GeneratorOp constructor have a default argument that is taken
// from the client config.
build_buffer_size_ = kCfgRowsPerBuffer;
build_op_connector_size_ = kCfgOpConnectorSize;
......
......@@ -72,7 +72,7 @@ class GeneratorOp : public PipelineOp {
}
// The builder "build" method creates the final object.
// @return shared_ptr to the new StorageOp object
// @return shared_ptr to the new GeneratorOp object
Status Build(std::shared_ptr<GeneratorOp> *);
private:
......
......@@ -198,7 +198,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void Print(std::ostream &out, bool show_all) const override;
// This function is a hack! It is to return the num_class and num_rows the old storageOp does. The result
// This function is a hack! It is to return the num_class and num_rows. The result
// returned by this function may not be consistent with what image_folder_op is going to return
// user this at your own risk!
static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
......
......@@ -44,7 +44,7 @@ using mindrecord::ShardReader;
MindRecordOp::Builder::Builder() : build_dataset_file_({}) {
// Some arguments to the MindRecordOp constructor have a default argument that is taken
// from the client config.
// The user may choose to change these values for the construction of the StorageOp by
// The user may choose to change these values for the construction of the MindRecordOp by
// using the various builder set methods.
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
......
......@@ -45,7 +45,7 @@ class PythonSampler : public Sampler {
Status ResetSampler() override;
// Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op
// @param int32_t workerId - not meant to be used
// @return - The error code return
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
......
......@@ -38,7 +38,7 @@ class RandomAccessOp {
// @return - The error code return
Status GetNumRowsInDataset(int64_t *num_rows) const;
// sampler gets label , imageIds from storageOp, this function is unique to PK
// sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK
// @param std::map<int64_t, std::vector<int64_t>> * map
// @return - The error code return
virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const {
......
......@@ -44,7 +44,7 @@ class SequentialSampler : public Sampler {
Status ResetSampler() override;
// Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op
// @param int32_t workerId - not meant to be used
// @return - The error code return
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define MAX_INTEGER_INT32 2147483647
#include <iostream>
#include <memory>
#include <utility>
#include <nlohmann/json.hpp>
#include "dataset/core/constants.h"
#include "dataset/engine/datasetops/source/storage_client.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/engine/datasetops/source/tf_client.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Name: Constructor
// Description:
StorageClient::StorageClient(std::unique_ptr<DataSchema> schema, // In: The schema for this storage client.
StorageOp *store_op) // In: The StorageOp that's using this client
: data_schema_(std::move(schema)), num_rows_in_dataset_(0), storage_op_(store_op), num_classes_(0) {}
// Name: Print()
// Description: A function that prints info about the StorageClient
// In: The output stream to print to
void StorageClient::Print(std::ostream &out) const {
// not much to show here folks!
// out << "Storage client:\n";
}
// This is a local-only static function to drive the switch statement for creating
// the storage client (not a static member function)
static Status CreateStorageClientSwitch(
std::unique_ptr<DataSchema> schema, // In: The schema to set into the client
StorageOp *store_op, // In: The StorageOp we are operating on
std::shared_ptr<StorageClient> *out_client) { // Out: the created storage client
switch (schema->dataset_type()) {
case DatasetType::kArrow: {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Storage client not implemented yet for arrow dataset type.");
}
case DatasetType::kTf: {
// Construct the derived class TFClient, stored as base class StorageClient
store_op->set_rows_per_buffer(32);
*out_client = std::make_unique<TFClient>(std::move(schema), store_op);
break;
}
case DatasetType::kUnknown:
default: {
RETURN_STATUS_UNEXPECTED("Invalid dataset type.");
}
}
if (*out_client) {
RETURN_IF_NOT_OK((*out_client)->Init());
}
return Status::OK();
}
// Name: CreateStorageClient()
// Description: A factory method to create the derived storage client.
// Every dataset has a required field for the dataset type in a config
// file. This type will determine the child class to return for the
// type of storage client. It also creates the schema and sticks it
// into the cache.
Status StorageClient::CreateStorageClient(
StorageOp *store_op, // In: A backpointer to the owning cache for this client.
std::string dataset_schema_path, // In: The path to the schema
std::shared_ptr<StorageClient> *out_client) { // Out: the created storage client
// Make a new schema first. This only assigns the dataset type. It does not
// create the columns yet.
auto new_schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(new_schema->LoadDatasetType(dataset_schema_path));
RETURN_IF_NOT_OK(CreateStorageClientSwitch(std::move(new_schema), store_op, out_client));
return Status::OK();
}
// Name: CreateStorageClient()
// Description: A factory method to create the derived storage client.
// This creator is a user-override for the schema properties where
// the user has input the layout of the data (typically used in testcases)
Status StorageClient::CreateStorageClient(
StorageOp *store_op, // In: A backpointer to the owning cache for this client.
DatasetType in_type, // In: The type of dataset
std::shared_ptr<StorageClient> *out_client) { // Out: the created storage client
// The dataset type is passed in by the user. Create an empty schema with only
// only the dataset type filled in and then create the client with it.
auto new_schema = std::make_unique<DataSchema>();
new_schema->set_dataset_type(in_type);
RETURN_IF_NOT_OK(CreateStorageClientSwitch(std::move(new_schema), store_op, out_client));
return Status::OK();
}
// Name: LoadDatasetLayout()
// Description: There are 2 ways to define the properties of the data in the storage
// layer: LoadDatasetLayout() and AssignDatasetLayout().
// LoadDatasetLayout() will parse the json config file that comes with
// the dataset.
Status StorageClient::LoadDatasetLayout() {
// Access the json file to populate our schema, assume the json file is accessible
// locally.
RETURN_IF_NOT_OK(data_schema_->LoadSchemaFile(storage_op_->schema_file(), storage_op_->columns_to_load()));
// The number of rows in the schema file is an optional config. For example,
// maybe the derived storage client will know how to determine the total number
// of rows a different way rather than having it in the schema config json file.
// Thus, mNumRowsInDataset can still be zero and force the derived class override
// to determine it another way.
uint32_t num_rows = 0;
RETURN_IF_NOT_OK(this->numRowsFromFile(num_rows));
CHECK_FAIL_RETURN_UNEXPECTED(num_rows <= MAX_INTEGER_INT32, "numRows exceeds the boundary numRows>2147483647");
if (num_rows_in_dataset_ == 0 || num_rows < num_rows_in_dataset_) {
num_rows_in_dataset_ = num_rows;
}
return Status::OK();
}
// Name: AssignDatasetLayout()
// Description: There are 2 ways to define the properties of the data in the storage
// layer: LoadDatasetLayout() and AssignDatasetLayout().
// AssignDatasetLayout() will take input from the caller and assign that
// info into the storage client.
Status StorageClient::AssignDatasetLayout(uint32_t num_rows, // In: The number of rows in the dataset
const DataSchema &schema) { // In: The schema for the dataset
// Since this is just an assignment into the storage client, you probably won't need
// to override this one in a derived class. First some sanity checks
CHECK_FAIL_RETURN_UNEXPECTED(data_schema_->dataset_type() == schema.dataset_type(),
"Assigning a schema into StorageClient with mismatched dataset types!");
CHECK_FAIL_RETURN_UNEXPECTED(data_schema_->NumColumns() == 0,
"Assigning a schema into StorageClient that already has non-empty schema!");
// The current schema was just an empty one with only the dataset field populated.
// Let's copy construct a new one that will be a copy of the input schema (releasing the old
// one) and then set the number of rows that the user requested.
data_schema_ = std::make_unique<DataSchema>(schema);
CHECK_FAIL_RETURN_UNEXPECTED(num_rows <= MAX_INTEGER_INT32, "numRows exceeds the boundary numRows>2147483647");
num_rows_in_dataset_ = num_rows;
return Status::OK();
}
// Name: numRowsFromFile()
// Description: Reads the schema json file to see if the optional numRows field has
// been set and returns it.
Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
std::string schemaFile = storage_op_->schema_file();
try {
std::ifstream in(schemaFile);
nlohmann::json js;
in >> js;
if (js.find("numRows") == js.end()) {
num_rows = MAX_INTEGER_INT32;
} else {
num_rows = js.value("numRows", 0);
}
if (num_rows == 0) {
std::string err_msg =
"Storage client has not properly done dataset "
"handshake to initialize schema and number of rows.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
// Catch any exception and rethrow it as our own
catch (const std::exception &err) {
std::ostringstream ss;
ss << "Schema file failed to load:\n" << err.what();
std::string err_msg = ss.str();
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
// Get'r function
DataSchema *StorageClient::schema() const { return data_schema_.get(); }
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_CLIENT_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_CLIENT_H_
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "dataset/engine/data_schema.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
// The Storage Client is the interface and base class that the StorageOp
// will use to perform any interactions with the storage layer.
// The different types of datasets will have different derived classes
// under that storage client super class.
class StorageClient {
public:
// Name: Constructor
// Description:
StorageClient(std::unique_ptr<DataSchema> schema, // In: The schema for this storage client.
StorageOp *store_op); // In: The StorageOp that's using this client
// Destructor
virtual ~StorageClient() { storage_op_ = nullptr; }
virtual Status Init() { return Status::OK(); }
// Name: CreateStorageClient()
// Description: A factory method to create the derived storage client.
// Every dataset has a required field for the dataset type in a config
// file. This type will determine the child class to return for the
// type of storage client.
static Status CreateStorageClient(StorageOp *store_op, // In: A backpointer to the owning storage op for this client.
std::string dataset_schema_path, // In: The path to the dataset
std::shared_ptr<StorageClient> *out_client); // Out: the created storage client
// Name: CreateStorageClient()
// Description: A factory method to create the derived storage client.
// This creator is a user-override for the schema properties where
// the user has input the layout of the data (typically used in testcases)
static Status CreateStorageClient(StorageOp *store_op, // In: A backpointer to the owning cache for this client.
DatasetType in_type, // In: The type of dataset
std::shared_ptr<StorageClient> *out_client); // Out: the created storage client
// Name: Print()
// Description: A function that prints info about the StorageClient
virtual void Print(std::ostream &out) const; // In: The output stream to print to
// Provide stream operator for displaying
friend std::ostream &operator<<(std::ostream &out, const StorageClient &storage_client) {
storage_client.Print(out);
return out;
}
// Name: LoadDatasetLayout()
// Description: There are 2 ways to define the properties of the data in the storage
// layer: LoadDatasetLayout() and AssignDatasetLayout().
// LoadDatasetLayout() will parse the json config file that comes with
// the dataset and internally populate row counts and schema.
virtual Status LoadDatasetLayout();
// Name: AssignDatasetLayout()
// Description: There are 2 ways to define the properties of the data in the storage
// layer: LoadDatasetLayout() and AssignDatasetLayout().
// AssignDatasetLayout() will take input from the caller and assign that
virtual Status AssignDatasetLayout(uint32_t num_rows, // In: The number of rows in the dataset
const DataSchema &schema); // In: The schema for the dataset
// Name: Reset()
// Description: Resets any state info inside the client back to it's initialized
// state.
virtual Status Reset() = 0;
// Name: IsMoreData
// Description: General routine to ask if more data exists in the storage side for
// a given buffer id.
virtual bool IsMoreData(uint32_t id) { return true; }
// Name: numRowsFromFile()
// Description: Reads the schema json file to see if the optional numRows field has
// been set and returns it.
Status numRowsFromFile(uint32_t &num_rows) const;
// Get'r functions
DataSchema *schema() const;
uint32_t num_rows() const { return num_rows_in_dataset_; }
// Name: rows_per_buffer()
// Description: This default version simply gives you the count of the requested
// rows per buffer that the user defined in the storage op.
// However, if some condition down in the storage client layers
// could result in a buffer that has a different number of rows,
// then the derived class can override this method to provide their
// own implementation.
virtual uint32_t rows_per_buffer() { return storage_op_->rows_per_buffer(); }
// Description: Get the label classes num. Only manifest and Imagenet dataset support this parameter
virtual uint32_t num_classes() const { return 0; }
protected:
std::unique_ptr<DataSchema> data_schema_; // The schema for the data
uint32_t num_rows_in_dataset_; // The number of rows in the dataset
StorageOp *storage_op_; // Back pointer to the owning storage operator.
std::vector<std::string> col_names_;
uint32_t num_classes_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_CLIENT_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_OP_H_
#include <condition_variable>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/data_schema.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Forward declares
template <typename T>
class Queue;
// A type for a container of DataBuffer shared_ptr's
using DataBuffers = std::vector<std::unique_ptr<DataBuffer>>;
// A type for the queue of buffer id's for workers to fetch.
using ActionQueue = std::vector<std::unique_ptr<Queue<int32_t>>>;
// Forward declare
class DataBuffer;
class StorageClient;
class StorageOp : public ParallelOp {
public:
// The nested builder class inside of the StorageOp is used to help manage all of the arguments
// for constructing it. Use the builder by setting each argument with the provided set methods,
// and then finally call the build method to execute the actual construction.
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder();
// Default destructor
~Builder() = default;
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetNumRows(int num_rows) {
build_num_rows_ = num_rows;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int rows_per_buffer) {
build_rows_per_buffer_ = rows_per_buffer;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetSchema(std::unique_ptr<DataSchema> schema) {
build_schema_ = std::move(schema);
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetWorkerConnectorSize(int32_t connector_size) {
build_worker_connector_size_ = connector_size;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetSchemaDir(const std::string &schema_dir) {
build_schema_file_ = schema_dir + "/datasetSchema.json";
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetSchemaFile(const std::string &schema_file) {
build_schema_file_ = schema_file;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetDatasetFilesDir(const std::string &files_dir) {
build_dataset_files_dir_ = files_dir;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetDatasetFileList(const std::vector<std::string> &file_list) {
build_dataset_file_list_ = file_list;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetColumnsToLoad(const std::vector<std::string> &columns) {
build_columns_to_load_ = columns;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetDataDistributionFile(const std::string &data_distribution_file) {
build_data_distribution_file_ = data_distribution_file;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &setLabelsFileName(const std::string &labels_file_name) {
build_labels_file_name_ = labels_file_name;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetDatasetUsage(const std::string &dataset_usage) {
build_dataset_usage_ = dataset_usage;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetBatchSize(int32_t batch_size) {
build_batch_size_ = batch_size;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetDropRemainder(bool drop_remainder) {
build_drop_remainder_ = drop_remainder;
return *this;
}
// The builder "build" method creates the final object.
// @param shared_ptr to the new StorageOp object
// @return Status - The error code return
Status Build(std::shared_ptr<StorageOp> *);
private:
// The builder saves all StorageOp construction arguments internally.
// The following are the arguments.
std::string build_dataset_files_dir_;
std::string build_schema_file_;
int32_t build_num_rows_;
std::string build_data_distribution_file_;
int32_t build_rows_per_buffer_;
int32_t build_worker_connector_size_;
int32_t build_num_workers_;
int32_t build_op_connector_size_;
std::unique_ptr<DataSchema> build_schema_;
std::vector<std::string> build_dataset_file_list_;
std::vector<std::string> build_columns_to_load_;
std::string build_labels_file_name_;
std::string build_dataset_usage_;
int32_t build_batch_size_;
bool build_drop_remainder_;
};
// Constructor of the StorageOp.
// @note The builder class should be used to call it
// @param num_workers - The number of workers for the op
// @param worker_connector_size - The internal connector size between workers and master
// @param rows_per_buffer - The requested number of rows per buffer
// @param op_connector_size - The output connector queue size
// @param columns_to_load - The list of columns to use (column name)
StorageOp(int32_t num_workers, int32_t worker_connector_size, int32_t rows_per_buffer, int32_t op_connector_size,
std::vector<std::string> columns_to_load, std::string data_distribution_file, int32_t batch_size,
bool drop_remainder);
// Init the StorageOp. This is 1 of 3 init.
// This version of the init does not take the schema in it's arguments. It must perform an
// internal handshake with the dataset to produce the schema.
// @note The builder class should be used to call it
// @param dataset_files_dir - The directory that has the dataset files
// @param schema_file - The schema file for providing column info
Status InitOp(const std::string &dataset_files_dir, const std::string &schema_file,
const std::string &labels_file_name, const std::string &dataset_usage);
// Init the StorageOp. This is 2 of 3 init.
// This version of the init allows the user to input the schema and other dataset properties rather
// than get it from the dataset itself.
// @note The builder class should be used to call it
// @param num_rows - The number of rows in the dataset
// @param dataset_files_dir - The directory that has the dataset files
// @param data_schema - The schema to use
Status InitOp(int32_t num_rows, const std::string &dataset_files_dir, std::unique_ptr<DataSchema> data_schema,
const std::string &labels_file_name, const std::string &dataset_usage);
// Init the StorageOp. This is 3 of 3 init.
// This version of the init does not take the schema in it's arguments. It must perform an
// internal handshake with the dataset to produce the schema. Unlike constructor 1, it takes a
// list of files rather than a directory.
// @note The builder class should be used to call it
// @param files_list - The list of files to use for the dataset
// @param schema_file - The schema file for providing column info
Status InitOp(const std::vector<std::string> &files_list, const std::string &schema_file);
// Destructor
~StorageOp();
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
// << Stream output operator overload
// @notes This allows you to write the debug print info using stream operators
// @param out - reference to the output stream being overloaded
// @param storage_op - reference to the StorageOp to display
// @return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const StorageOp &storage_op) {
storage_op.Print(out, false);
return out;
}
// Class functor operator () override.
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work.
// @return Status - The error code return
Status operator()() override;
// The entry point code for when workers are launched.
// @param worker_id - The worker id
// @return Status - The error code return
Status WorkerEntry(int32_t worker_id) override;
// The entry point code for when workers are launched.
// Given the input bufferId, it returns a shared_ptr to that buffer back to you by driving a
// load operation. This function is intended to be run by worker threads, when they are
// populating the memory with the actual data of the buffer.
// @param buffer_id - The buffer id to get.
// @param ptr - Pointer to shared_ptr to the buffer that was loaded in.
// @return Status - The error code return
Status GetBuffer(int32_t buffer_id, std::unique_ptr<DataBuffer> *ptr);
// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from it's previous execution and then initializes itself so that it can be executed
// again.
// @return Status - The error code return
Status Reset() override;
// Getter method
int32_t num_rows() const { return num_rows_; }
// Setter method
void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }
// Getter method
int32_t rows_per_buffer() const { return rows_per_buffer_; }
// Setter method
void set_rows_per_buffer(int32_t rows_per_buffer) { rows_per_buffer_ = rows_per_buffer; }
// Getter method
std::string dataset_files_dir() const { return dataset_files_dir_; }
// Getter method
std::vector<std::string> dataset_file_list() const { return dataset_file_list_; }
// Getter method
std::string schema_file() const { return schema_file_; }
// Getter method
const DataSchema *schema() const;
// Getter method
const std::vector<std::string> columns_to_load() const { return columns_to_load_; }
// Getter method
std::string data_distribution_file() const { return data_distribution_file_; }
// Getter method
int32_t device_num() const { return device_num_; }
// Getter method
int32_t device_id() const { return device_id_; }
// Getter method
std::string shard_config() const { return shard_config_; }
// Getter method
uint32_t seed() const { return seed_; }
// Getter method
bool shuffle_config() const { return shuffle_config_; }
// Getter method
int32_t num_classes() const { return num_classes_; }
// Getter method
std::string labels_file_name() const { return labels_file_name_; }
// Getter method
std::string dataset_usage() const { return dataset_usage_; }
// Getter method
int32_t batch_size() const { return batch_size_; }
// Getter method
bool drop_remainder() const { return drop_remainder_; }
private:
// Private helper method. This one populates the action queue with the list of buffer ids.
// @param randomize - T/F if the id's in the action queue should be randomized or sequential.
Status FillActionQueue(bool randomize);
// Private helper method. This one encapsulates some common construction/reset tasks and is
// designed to be re-entrant so that you can re-init a previously used StorageOp without needing
// to redo the storage client handshake.
// @return Status - The error code return
Status init();
// Private helper method. This one posts a control indicator for each worker thread to consume
// from the action queue. When the worker pops this msg, it will shut itself down gracefully.
// @return Status - The error code return
Status PostEndOfData();
Status LoadParallelConfig();
DataBuffers data_buffers_; // A vector of pointers to buffers
std::shared_ptr<StorageClient> store_client_; // The client for interacting with storage
ActionQueue action_queue_; // The queues of buffer id's for workers to fetch.
int32_t worker_conn_size_; // connector size for internal worker queue
int32_t rows_per_buffer_; // The number of requested rows per buffer.
int32_t num_rows_; // One more than the last row id in the range for this cache
std::string dataset_files_dir_; // The path for the dataset files
std::vector<std::string> dataset_file_list_; // List of paths to files for the dataset
int32_t buffers_fetched_; // Counter for the buffers that were fetched
std::string schema_file_; // Path to the schema json file
std::vector<std::string> columns_to_load_; // Columns to load from dataset
std::string data_distribution_file_; // Distribution configuration file
int32_t device_num_; // All device number
int32_t device_id_; // Device id
std::string shard_config_; // ALL UNIQUE RANDOM
uint32_t seed_; // Used for shuffle
bool shuffle_config_; // True or false
std::string labels_file_name_; // File name of labels
int32_t num_classes_; // Label class number
std::string dataset_usage_; // train/eval/inference
int32_t batch_size_;
bool drop_remainder_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_OP_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/tf_buffer.h"
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include "common/utils.h"
#include "utils/log_adapter.h"
#include "dataset/engine/datasetops/source/tf_client.h"
#include "dataset/core/data_type.h"
#include "dataset/engine/datasetops/source/storage_client.h"
#include "dataset/engine/data_schema.h"
namespace mindspore {
namespace dataset {
// constructor
TFBuffer::TFBuffer(
uint32_t id, // In: The id for this buffer
BufferFlags flags, // In: The flags for this buffer
const std::shared_ptr<StorageClient> &storage_client) // In: Storage client that is related to this buffer type
: DataBuffer(id, flags), storage_client_(storage_client) {}
// destructor
TFBuffer::~TFBuffer() {}
// Name: print()
// Description: A function that prints info
void TFBuffer::Print(std::ostream &out, // In: The output stream to print to
bool show_all) const { // In: T/F if it should print everything
out << "TFBuffer print\n";
// Call base class printer
DataBuffer::Print(out, show_all);
}
// Name: load()
// Description: populates the DataBuffer with data
// Overrides base-class method.
Status TFBuffer::Load() {
const DataSchema *the_schema = storage_client_->schema();
uint32_t num_columns = the_schema->NumColumns();
uint32_t num_rows_requested = storage_client_->rows_per_buffer();
uint32_t remaining_rows = storage_client_->num_rows() > buffer_id_ * storage_client_->rows_per_buffer()
? storage_client_->num_rows() - buffer_id_ * storage_client_->rows_per_buffer()
: 0;
if (remaining_rows < num_rows_requested) {
num_rows_requested = remaining_rows;
}
// Construct the Tensor table for this buffer.
tensor_table_ = std::make_unique<TensorQTable>();
// At each position in the tensor table, instantiate the shared pointer to it's Tensor.
uint32_t row = 0;
while (row < num_rows_requested && (cur_reader_.peek() != EOF || storage_client_->IsMoreData(buffer_id_))) {
TensorRow new_row;
// Read the data from storage into a tf_file format
dataengine::Example tf_file;
RETURN_IF_NOT_OK(ParseSingleExample(&tf_file));
for (uint32_t col = 0; col < num_columns; ++col) {
std::shared_ptr<Tensor> new_t;
const ColDescriptor current_col = the_schema->column(col);
const dataengine::Features &example_features = tf_file.features();
const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
const dataengine::Feature &column_values_list = feature_map.at(current_col.name());
const dataengine::Feature::KindCase column_list_type = column_values_list.kind_case();
RETURN_IF_NOT_OK(LoadFeature(column_list_type, column_values_list, current_col, &new_t));
// Add the column to the current tensor row
new_row.push_back(std::move(new_t));
}
// Add the new row of tensors to the end of our tensor table
tensor_table_->push_back(new_row);
row++;
}
cur_reader_.close();
return Status::OK();
}
// Name: ParseSingleExample()
// Description: Drives the calls to TFClient for fetching the tf_file info from
// the tf_file files. Returns a single row of data from the tf_file
// files.
Status TFBuffer::ParseSingleExample(dataengine::Example *ptr) {
if (cur_reader_.peek() == EOF) {
auto client = std::dynamic_pointer_cast<TFClient>(storage_client_);
if (client == nullptr) {
std::string errMsg = "Unexpected storage client type for TFBuffer";
RETURN_STATUS_UNEXPECTED(errMsg);
}
RETURN_IF_NOT_OK(client->NextFileInfo(buffer_id_, &cur_f_info_));
cur_reader_.close();
cur_reader_.open(cur_f_info_.fileName);
// Seek to the offset
(void)cur_reader_.seekg(static_cast<std::streamsize>(cur_f_info_.startOffset));
MS_LOG(DEBUG) << "got new file " << cur_f_info_.fileName << ".";
}
// one record in tf_file looks like:
// Format of a single record:
// uint64 length
// uint32 masked crc of length
// byte data[length]
// uint32 masked crc of data
// read length
if (cur_reader_.peek() == EOF) {
MS_LOG(ERROR) << "ParseSingleExample failed";
}
dataengine::Example tf_file;
try {
uint64_t record_length = 0;
(void)cur_reader_.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(uint64_t)));
// ignore crc header
(void)cur_reader_.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));
// read serialized Example
std::string serialized_example;
serialized_example.resize(record_length);
(void)cur_reader_.read(&serialized_example[0], static_cast<std::streamsize>(record_length));
// ignore crc footer
(void)cur_reader_.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));
if (!tf_file.ParseFromString(serialized_example)) {
std::string err_msg = "parse tf_file failed";
RETURN_STATUS_UNEXPECTED(err_msg);
}
} catch (const std::exception &err) {
std::string err_msg = "Please check if the data file is complete!";
RETURN_STATUS_UNEXPECTED(err_msg);
}
*ptr = tf_file;
return Status::OK();
}
// Name: LoadFeature()
// Description: Given the column type of the tf record and the values list,
// constructs the tensor and returns it.
Status TFBuffer::LoadFeature(const dataengine::Feature::KindCase &column_list_type,
const dataengine::Feature &column_values_list, const ColDescriptor &current_col,
std::shared_ptr<Tensor> *out_tensor) {
std::string element_str; // For staging data from protobuf deserialization
std::unique_ptr<int64_t[]> int_array; // For staging data from protobuf deserialization
std::unique_ptr<float[]> float_array; // For staging data from protobuf deserialization
const unsigned char *data_ptr = nullptr; // Generic pointer used for populating the Tensor
// This variable will point into the above staging
// variables.
uint32_t num_elements = 0; // Generic counter used for setting shape attributes
// Depending on the type of data from the tf_file, we want to extract 2 things:
// 1) A pointer to the data as a const unsigned char *
// 2) The number of elements of the data
// After those are determined, we can then build the tensor to represent this data.
switch (column_list_type) {
// CASE : TF record type: kBytesList
case dataengine::Feature::KindCase::kBytesList: {
RETURN_IF_NOT_OK(LoadBytesList(current_col, column_values_list, &element_str));
// Get the const pointer representation of this data, and the number of elements
// (number of bytes) for this tensor.
data_ptr = reinterpret_cast<const unsigned char *>(common::SafeCStr(element_str));
num_elements = element_str.length();
break;
}
// CASE : TF record type: kFloatList
case dataengine::Feature::KindCase::kFloatList: {
RETURN_IF_NOT_OK(LoadFloatList(current_col, column_values_list, &num_elements, &float_array));
data_ptr = reinterpret_cast<const unsigned char *>(float_array.get());
break;
}
// CASE : TF record type: kInt64List
case dataengine::Feature::KindCase::kInt64List: {
RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, &num_elements, &int_array));
data_ptr = reinterpret_cast<const unsigned char *>(int_array.get());
break;
}
case dataengine::Feature::KindCase::KIND_NOT_SET: {
std::string errMsg = "tf_file column list type enum is KIND_NOT_SET";
RETURN_STATUS_UNEXPECTED(errMsg);
}
default: {
std::string errMsg = "tf_file column list type enum does not match any known DE type";
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
// At this point we have a raw pointer to the data, and we have the number of elements.
// Along with the tensor implementation type and the data type from the schema, we
// enough info to construct the Tensor for it.
TensorShape current_shape = TensorShape::CreateUnknownRankShape();
RETURN_IF_NOT_OK(CreateTensorShapeForColumn(current_col, num_elements, &current_shape));
// Now, create this tensor directly into the appropriate slot in our tensor
// table.
RETURN_IF_NOT_OK(
Tensor::CreateTensor(out_tensor, current_col.tensorImpl(), current_shape, current_col.type(), data_ptr));
return Status::OK();
}
Status TFBuffer::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
std::string *element_str) {
// kBytesList can map to the following DE types ONLY!
// DE_UINT8, DE_INT8
// Must be single byte type for each element!
if (current_col.type() != DataType::DE_UINT8 && current_col.type() != DataType::DE_INT8) {
std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name();
RETURN_STATUS_UNEXPECTED(err_msg);
}
const dataengine::BytesList &bytes_list = column_values_list.bytes_list();
// A bytesList is a special case where the entire list of data can be
// deserialized into a single string. For example, it is not a list
// of bytes, it is a list of strings, where each string represents
// a list of bytes (this is different from the other cases like IntList etc)
// As such, if there is more than one string in this list, that is invalid.
if (bytes_list.value_size() > 1) {
std::string err_msg = "Bytes list contains more than one element for column: " + current_col.name();
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Extract the string that contains the bytes we need. Position 0 is the only
// valid string here.
*element_str = bytes_list.value(0);
return Status::OK();
}
Status TFBuffer::LoadFloatList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
uint32_t *num_elements, std::unique_ptr<float[]> *float_array) {
// KFloatList can only map to DE types:
// DE_FLOAT32
if (current_col.type() != DataType::DE_FLOAT32) {
std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name();
RETURN_STATUS_UNEXPECTED(err_msg);
}
const dataengine::FloatList &float_list = column_values_list.float_list();
// Identify how many values we have and then create a local array of these
// to deserialize into
*num_elements = float_list.value_size();
*float_array = std::make_unique<float[]>(*num_elements);
for (int i = 0; i < float_list.value_size(); i++) {
(*float_array)[i] = float_list.value(i);
}
return Status::OK();
}
Status TFBuffer::LoadIntList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
uint32_t *num_elements, std::unique_ptr<int64_t[]> *int_array) {
// KInt64List can only map to DE types:
// DE_UINT64, DE_INT64, DE_UINT32, DE_INT32, DE_UINT16, DE_INT16, DE_UINT8, DE_INT8
if (!(current_col.type().IsInt())) {
std::string err_msg = "Invalid datatype/rank for column label in TFBuffer.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
const dataengine::Int64List &int64_list = column_values_list.int64_list();
// Identify how many values we have and then create a local array of these
// to deserialize into
*num_elements = int64_list.value_size();
*int_array = std::make_unique<int64_t[]>(*num_elements);
for (int i = 0; i < int64_list.value_size(); i++) {
(*int_array)[i] = int64_list.value(i);
}
return Status::OK();
}
Status TFBuffer::CreateTensorShapeForColumn(const ColDescriptor &current_col, uint32_t num_elements,
TensorShape *current_shape) {
// If the shape is assigned by user, we have an assumption that the data is
// already in the appropriate format that we can copy into the Tensor as-is.
if (current_col.hasShape()) {
*current_shape = current_col.shape();
} else if (current_col.rank() == 1) {
// If shape was not given, then we support 2 possible shapes.
// 1) It's a scalar (rank 0), in which case the shape is empty but we need to flag
// it as a scalar value (empty shape but has a single value)
// 2) It's a rank 1 shape, and the dimension value for that single dimension will
// be comprised of the entire bytes-size of the input data.
*current_shape = TensorShape({num_elements});
} else if (current_col.rank() == 0) {
// Make this shape into a single value scalar.
*current_shape = TensorShape::CreateScalar();
} else if (current_col.rank() > 1) {
// All other ranks, except for 0, are invalid because we cannot guess
// what the shape will be. For example, if we have rank 3 and 12 bytes
// of data, is it shape {2,2,3} or is it {2,6,1}. We can't guess at
// the shape dimensions.
const std::string kErrMsg = "Invalid rank (rank>1) for dynamic shape construction. Specify shape in schema.";
RETURN_STATUS_UNEXPECTED(kErrMsg);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TF_BUFFER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_TF_BUFFER_H_
#include <fstream>
#include <memory>
#include <string>
#include <vector>
#include "dataset/engine/data_buffer.h"
#include "proto/example.pb.h"
#include "dataset/engine/datasetops/source/tf_client.h"
namespace mindspore {
namespace dataset {
// This TFBuffer is the buffer type for dealing with tf record data.
class TFBuffer : public DataBuffer {
public:
// constructor
TFBuffer(uint32_t id, // In: The id for this buffer
DataBuffer::BufferFlags flags, // In: The flags for this buffer
const std::shared_ptr<StorageClient>
&storage_client); // In: The storage client that is related to this buffer type
// destructor
~TFBuffer() override;
// Name: print()
// Description: A function that prints info
void Print(std::ostream &out, // In: The output stream to print to
bool show_all) const override; // In: T/F if it should print everything
// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const TFBuffer &tf_buffer) {
tf_buffer.Print(out, false); // Show meta info only
return out;
}
// Name: load()
// Description: populates the DataBuffer with data.
// Overrides base-class method.
Status Load() override;
private:
std::ifstream cur_reader_;
FileInfo cur_f_info_;
std::shared_ptr<StorageClient> storage_client_; // The storage client for populating the buffer initially.
// Name: ParseSingleExample()
// Description: Drives the calls to TFClient for fetching the tf_file info from
// the tf_file files. Returns a single row of data from the tf_file
// files.
Status ParseSingleExample(dataengine::Example *ptr);
// Name: LoadFeature()
// Description: Given the column type of the tf record and the values list,
// constructs the tensor and returns it.
Status LoadFeature(const dataengine::Feature::KindCase &column_list_type,
const dataengine::Feature &column_values_list, const ColDescriptor &current_col,
std::shared_ptr<Tensor> *out_tensor);
Status LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
std::string *element_str);
Status LoadFloatList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
uint32_t *num_elements, std::unique_ptr<float[]> *float_array);
Status LoadIntList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
uint32_t *num_elements, std::unique_ptr<int64_t[]> *int_array);
Status CreateTensorShapeForColumn(const ColDescriptor &current_col, uint32_t num_elements,
TensorShape *current_shape);
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TF_BUFFER_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/tf_client.h"
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <limits>
#include <algorithm>
#include "common/utils.h"
#include "proto/example.pb.h"
#include "dataset/engine/datasetops/source/storage_client.h"
#include "dataset/util/path.h"
#include "dataset/util/status.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// Name: Constructor
// Description: Creates the TFClient.
TFClient::TFClient(std::unique_ptr<DataSchema> schema, // In: The schema for this storage client.
StorageOp *so) // In: The StorageOp that's using this client
: StorageClient(std::move(schema), so),
rows_per_buffer_(so->rows_per_buffer()),
random_seed_generator_(so->seed()),
random_seed_distribution_(0, std::numeric_limits<uint32_t>::max()),
rows_per_shard_(0) {}
Status TFClient::Init() {
// Initialize queue to hold the tf file names
const std::string kExtensionData = ".data";
const std::string kExtensionTF = ".tfrecord";
bool schema_init = false;
if (!storage_op_->dataset_files_dir().empty()) {
MS_LOG(DEBUG) << "Reading dataset using datasetPath.";
Path data_set_directory(storage_op_->dataset_files_dir());
auto dirIt = Path::DirIterator::OpenDirectory(&data_set_directory);
if (dirIt) {
while (dirIt->hasNext()) {
Path file = dirIt->next();
std::string filename = file.toString();
if ((file.Extension() == kExtensionData) || (file.Extension() == kExtensionTF)) {
const std::vector<uint64_t> recs_lengths = ParseTfFileLines(filename);
v_total_file_rows_.emplace_back(
std::pair<std::string, std::vector<uint64_t>>(filename, std::move(recs_lengths)));
// schema
if (!schema_init) {
RETURN_IF_NOT_OK(ParseTfFileSchema(filename));
schema_init = true;
}
MS_LOG(INFO) << "found tf file: " << filename << ", num rows " << recs_lengths.size() << ".";
}
}
} else {
RETURN_STATUS_UNEXPECTED("Unable to open directory " + data_set_directory.toString());
}
} else {
MS_LOG(DEBUG) << "Reading dataset using dataset files list.";
for (auto filename : storage_op_->dataset_file_list()) {
const std::vector<uint64_t> recs_lengths = ParseTfFileLines(filename);
v_total_file_rows_.emplace_back(std::pair<std::string, std::vector<uint64_t>>(filename, std::move(recs_lengths)));
// schema
if (!schema_init) {
RETURN_IF_NOT_OK(ParseTfFileSchema(filename));
schema_init = true;
}
MS_LOG(INFO) << "Processed tf file: " << filename << ", num rows " << recs_lengths.size() << ".";
}
}
RETURN_IF_NOT_OK(CalculateRowsPerDevice());
std::sort(v_total_file_rows_.begin(), v_total_file_rows_.end());
RETURN_IF_NOT_OK(ScatterFileRows(static_cast<uint32_t>(storage_op_->device_id()), storage_op_->shard_config(),
storage_op_->seed(), storage_op_->shuffle_config()));
CalculateNumRows();
InitStateInfo();
return Status::OK();
}
// Sharding will reduce the number of rows. Doing this in constructor as we only want to do this once.
void TFClient::CalculateNumRows() {
num_rows_in_dataset_ = 0;
for (auto rows : file_start_end_offset_) {
num_rows_in_dataset_ += (rows.second - rows.first);
}
}
Status TFClient::CalculateRowsPerDevice() {
uint64_t num = std::accumulate(
v_total_file_rows_.begin(), v_total_file_rows_.end(), 0,
[](uint64_t value, const std::pair<std::string, std::vector<uint64_t>> &a) { return value + a.second.size(); });
if (static_cast<uint64_t>(std::floor(num * 1.0 / storage_op_->device_num())) == 0) {
RETURN_STATUS_UNEXPECTED("Num rows of dataset is less than device number");
}
rows_per_shard_ = static_cast<uint64_t>(std::ceil(num * 1.0 / storage_op_->device_num()));
return Status::OK();
}
bool TFClient::ValidFileForShard(const uint64_t file_rows, uint64_t *start_offset, uint64_t *end_offset,
const uint64_t &pre_count, uint32_t device_id) const {
*start_offset = 0;
*end_offset = 0;
bool valid = false;
uint64_t start_index = device_id * rows_per_shard_;
uint64_t end_index = (device_id + 1) * rows_per_shard_;
// First valid file
if (pre_count <= start_index && pre_count + file_rows > start_index) {
*start_offset = start_index - pre_count;
valid = true;
if (pre_count < end_index && pre_count + file_rows >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = file_rows;
}
}
// Second and subsequent files
if (pre_count > start_index && pre_count < end_index) {
*start_offset = 0;
valid = true;
if (pre_count + file_rows >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = file_rows;
}
}
return valid;
}
void TFClient::GetValidFileForShard(const std::vector<std::pair<std::string, std::vector<uint64_t>>> &v_files,
uint32_t device_id) {
uint64_t start_offset = 0;
uint64_t end_offset = 0;
uint64_t pre_count = 0;
bool finish = false;
while (!finish) {
for (const auto &file : v_files) {
if (ValidFileForShard(file.second.size(), &start_offset, &end_offset, pre_count, device_id)) {
std::pair<uint32_t, uint32_t> offset(start_offset, end_offset);
file_start_end_offset_.emplace_back(offset);
v_file_rows_.emplace_back(file);
}
pre_count += file.second.size();
}
if (pre_count < (device_id + 1) * rows_per_shard_) {
finish = false;
} else {
finish = true;
}
}
}
// Description: Scatter file rows to local single-P according to config info.
// There are 3 modes: ALL, UNIQUE, RANDOM. For UNIQUE and RANDOM mode, shuffleConfig controls
// whether file row vector would be shuffled or not before a new mEopch.
// For ALL mode, temporarily, we deal with epoch in python part.
Status TFClient::ScatterFileRows(uint32_t device_id, const std::string &shard_config, uint32_t seed,
bool shuffle_config) {
if (shard_config == "UNIQUE" || shard_config == "RANDOM") {
std::vector<std::pair<std::string, std::vector<uint64_t>>> v_shuffled_total_file_rows =
ShuffleVector(v_total_file_rows_, seed);
GetValidFileForShard(v_shuffled_total_file_rows, device_id);
if (shuffle_config) {
v_total_file_rows_ = v_shuffled_total_file_rows;
}
} else if (shard_config == "ALL") {
v_file_rows_.insert(v_file_rows_.end(), v_total_file_rows_.begin(), v_total_file_rows_.end());
if (shuffle_config) {
v_total_file_rows_ = ShuffleVector(v_total_file_rows_, seed);
}
for (const auto &file : v_file_rows_) {
std::pair<uint32_t, uint32_t> offset(0, file.second.size());
file_start_end_offset_.emplace_back(offset);
}
} else {
RETURN_STATUS_UNEXPECTED("In parallel config file, wrong shuffleConfig or shardConfig provided.");
}
return Status::OK();
}
std::vector<std::pair<std::string, std::vector<uint64_t>>> TFClient::ShuffleVector(
std::vector<std::pair<std::string, std::vector<uint64_t>>> v, uint32_t seed = 1) {
std::default_random_engine randomEngine(seed);
std::shuffle(std::begin(v), std::end(v), randomEngine);
return v;
}
void TFClient::CalculateStartOffset(const uint64_t start_index, const uint64_t end_index,
const std::vector<uint64_t> &vec_length, uint64_t *start_offset) const {
for (size_t i = start_index; i < end_index; i++) {
// Format of a single record:
// uint64 length
// uint32 masked crc of length
// byte data[length]
// uint32 masked crc of data
*start_offset += sizeof(uint64_t) + 2 * sizeof(uint32_t) + vec_length[i];
}
}
void TFClient::InitStateInfo() {
uint32_t start_idx = 0, record_num = 0, buffer_id = 0;
uint64_t start_offset = 0;
bool first_buffer = true;
f_info_queue_.emplace_back(QFile());
std::vector<std::pair<std::string, std::vector<uint64_t>>>::iterator itr = v_file_rows_.begin();
uint32_t index = 0;
while (itr != v_file_rows_.end()) {
uint32_t file_start_index = file_start_end_offset_[index].first;
uint32_t file_end_index = file_start_end_offset_[index].second;
FileInfo f_info;
f_info.fileName = itr->first;
f_info.startRecordIdx = start_idx > file_start_index ? start_idx : file_start_index;
if (first_buffer && f_info.startRecordIdx != 0) {
CalculateStartOffset(0, f_info.startRecordIdx, itr->second, &start_offset);
start_idx = static_cast<uint32_t>(f_info.startRecordIdx);
}
first_buffer = false;
f_info.startOffset = start_offset;
if (start_idx + rows_per_buffer_ - record_num < itr->second.size()) {
uint64_t end_idx = start_idx + rows_per_buffer_ - record_num - 1;
f_info.endRecordIdx = end_idx > (file_end_index - 1) ? (file_end_index - 1) : end_idx;
f_info_queue_[buffer_id].push(f_info);
CalculateStartOffset(start_idx, f_info.endRecordIdx + 1, itr->second, &start_offset);
start_idx = start_idx + rows_per_buffer_ - record_num;
record_num = 0;
buffer_id++;
f_info_queue_.emplace_back(QFile());
if (end_idx >= file_end_index - 1) {
start_idx = start_offset = 0;
++itr;
++index;
}
} else {
f_info.endRecordIdx = itr->second.size() - 1 > file_end_index - 1 ? file_end_index - 1 : itr->second.size() - 1;
f_info_queue_[buffer_id].push(f_info);
if (start_idx + rows_per_buffer_ - record_num == itr->second.size()) {
record_num = start_idx = start_offset = 0;
buffer_id++;
if (itr + 1 != v_file_rows_.end()) {
f_info_queue_.emplace_back(QFile());
}
} else {
record_num += static_cast<uint32_t>(itr->second.size()) - start_idx;
start_idx = start_offset = 0;
}
++itr;
++index;
}
}
}
// Name: Print()
// Description: A function that prints info about the TFClient
void TFClient::Print(std::ostream &out) const { // In: The output stream to print to
out << "TF client.";
}
std::vector<uint64_t> TFClient::ParseTfFileLines(const std::string &filename) {
std::vector<uint64_t> recs_lengths;
std::ifstream reader;
reader.open(filename);
while (true) {
if (reader.peek() == EOF) {
reader.close();
break;
}
// read length
uint64_t record_length = 0;
(void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(uint64_t)));
recs_lengths.push_back(record_length);
// ignore crc header
(void)reader.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));
// ignore data length
(void)reader.ignore(static_cast<std::streamsize>(record_length));
// ignore crc footer
(void)reader.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));
}
return recs_lengths;
}
Status TFClient::ParseTfFileSchema(const std::string &filename) {
std::ifstream reader;
reader.open(filename);
std::string serialized_example;
// read length
uint64_t record_length = 0;
(void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(uint64_t)));
// ignore crc header
(void)reader.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));
// read serialized Example
serialized_example.resize(record_length);
(void)reader.read(&serialized_example[0], static_cast<std::streamsize>(record_length));
// ignore crc footer
(void)reader.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));
reader.close();
dataengine::Example tf_file;
if (!tf_file.ParseFromString(serialized_example)) {
std::string err_msg = "parse tf_file failed, file name is " + filename;
RETURN_STATUS_UNEXPECTED(err_msg);
}
const dataengine::Features &example_features = tf_file.features();
const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
for (auto it = feature_map.begin(); it != feature_map.end(); ++it) {
col_names_.push_back(it->first);
}
return Status::OK();
}
// Name: Reset()
// Description: Resets any state info inside the client back to it's initialized
// state.
Status TFClient::Reset() {
v_file_rows_.clear();
file_start_end_offset_.clear();
uint32_t next_seed = random_seed_distribution_(random_seed_generator_);
RETURN_IF_NOT_OK(ScatterFileRows(static_cast<uint32_t>(storage_op_->device_id()), storage_op_->shard_config(),
next_seed, storage_op_->shuffle_config()));
CalculateNumRows();
uint32_t num_rows_in_file = 0;
RETURN_IF_NOT_OK(this->numRowsFromFile(num_rows_in_file));
if (num_rows_in_file < num_rows_in_dataset_) {
num_rows_in_dataset_ = num_rows_in_file;
}
storage_op_->set_num_rows(static_cast<int32_t>(num_rows_in_dataset_));
InitStateInfo();
return Status::OK();
}
Status TFClient::NextFileInfo(uint32_t id, FileInfo *ptr) {
if (f_info_queue_.empty() || id >= f_info_queue_.size() || f_info_queue_[id].empty()) {
RETURN_STATUS_UNEXPECTED("cannot find next FileInfo in mFInfoQueue");
}
*ptr = f_info_queue_[id].front();
f_info_queue_[id].pop();
return Status::OK();
}
bool TFClient::IsMoreData(uint32_t id) { return (!f_info_queue_[id].empty()); }
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TF_CLIENT_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_TF_CLIENT_H_
#include <fstream>
#include <iostream>
#include <memory>
#include <queue>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include "proto/example.pb.h"
#include "dataset/engine/datasetops/source/storage_client.h"
#include "dataset/util/status.h"
struct FileInfo {
std::string fileName;
uint64_t startRecordIdx;
uint64_t endRecordIdx;
uint64_t startOffset;
};
using QFile = std::queue<FileInfo>;
namespace mindspore {
namespace dataset {
// forward declares
class DataSchema;
class ParallelOp;
class TFClient : public StorageClient {
public:
// Name: Constructor
// Description: Creates the TFClient.
TFClient(std::unique_ptr<DataSchema> schema, // In: The schema for this storage client.
StorageOp *so); // In: The ParallelOp that's using this client
~TFClient() {}
Status Init() override;
// Name: Print()
// Description: A function that prints info about the TFClient
void Print(std::ostream &out) const override; // In: The output stream to print to
std::vector<uint64_t> ParseTfFileLines(const std::string &filename);
Status ParseTfFileSchema(const std::string &filename);
Status NextFileInfo(uint32_t id, FileInfo *);
bool IsMoreData(uint32_t id) override;
// Name: Reset()
// Description: Resets any state info inside the client back to it's initialized
// state.
Status Reset() override;
Status ScatterFileRows(uint32_t device_id, const std::string &shard_config, uint32_t seed, bool shuffle_config);
private:
// hardcoded, put this in json schema
// const static int32_t BERT_DATASET_TOTAL_ROWS = 43900;
uint32_t rows_per_buffer_;
std::default_random_engine random_seed_generator_;
std::uniform_int_distribution<uint32_t> random_seed_distribution_;
std::vector<std::pair<std::string, std::vector<uint64_t>>> v_file_rows_;
std::vector<std::pair<std::string, std::vector<uint64_t>>> v_total_file_rows_;
std::vector<QFile> f_info_queue_;
uint64_t rows_per_shard_;
std::vector<std::pair<uint32_t, uint32_t>> file_start_end_offset_;
void InitStateInfo();
std::vector<std::pair<std::string, std::vector<uint64_t>>> ShuffleVector(
std::vector<std::pair<std::string, std::vector<uint64_t>>> v, uint32_t seed);
Status CalculateRowsPerDevice();
bool ValidFileForShard(const uint64_t file_rows, uint64_t *start_offset, uint64_t *end_offset,
const uint64_t &pre_count, uint32_t device_id) const;
void CalculateNumRows();
void GetValidFileForShard(const std::vector<std::pair<std::string, std::vector<uint64_t>>> &v_files,
uint32_t device_id);
void CalculateStartOffset(const uint64_t start_index, const uint64_t end_index,
const std::vector<uint64_t> &vec_length, uint64_t *start_offset) const;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TF_CLIENT_H_
......@@ -16,6 +16,7 @@
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include <algorithm>
#include <fstream>
#include <future>
#include <iomanip>
#include <memory>
......@@ -32,8 +33,6 @@
#include "dataset/engine/connector.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/storage_client.h"
#include "dataset/engine/datasetops/source/tf_client.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/jagged_connector.h"
......
......@@ -40,7 +40,7 @@ class TakeOp : public PipelineOp {
~Builder() = default;
// The builder "build" method creates the final object.
// @return shared_ptr to the new StorageOp object
// @return shared_ptr to the new TakeOp object
Status Build(std::shared_ptr<TakeOp> *);
private:
......
......@@ -65,7 +65,7 @@ class ZipOp : public PipelineOp {
}
// The builder "build" method creates the ZipOp dataset Operator.
// @return shared_ptr to the new StorageOp object
// @return shared_ptr to the new ZipOp object
Status Build(std::shared_ptr<ZipOp> *);
private:
......
......@@ -27,7 +27,6 @@
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/take_op.h"
......
......@@ -33,169 +33,6 @@ valid_detype = [
]
def check(method):
"""Check the function parameters and return the function ."""
func_name = method.__name__
# Required parameter
req_param_int = []
req_param_bool = []
# Non-required parameter
nreq_param_int = []
nreq_param_bool = []
if func_name in 'repeat':
nreq_param_int = ['count', 'prefetch_size']
if func_name in 'take':
req_param_int = ['count']
nreq_param_int = ['prefetch_size']
elif func_name in 'shuffle':
req_param_int = ['buffer_size']
nreq_param_bool = ['reshuffle_each_iteration']
nreq_param_int = ['prefetch_size', 'seed']
elif func_name in 'batch':
req_param_int = ['batch_size']
nreq_param_int = ['num_parallel_workers', 'prefetch_size']
nreq_param_bool = ['drop_remainder']
elif func_name in ('zip', 'filter', 'cache', 'rename', 'project'):
nreq_param_int = ['prefetch_size']
elif func_name in ('map', '__init__'):
nreq_param_int = ['num_parallel_workers', 'prefetch_size', 'seed']
nreq_param_bool = ['block_reader']
@wraps(method)
def wrapper(*args, **kwargs):
def _make_key():
sig = ins.signature(method)
params = sig.parameters
keys = list(params.keys())
param_dic = dict()
for name, value in enumerate(args):
param_dic[keys[name]] = value
param_dic.update(zip(params.keys(), args))
param_dic.update(kwargs)
for name, value in params.items():
if name not in param_dic:
param_dic[name] = value.default
return param_dic
# check type
def _check_param_type(arg, param_name, param_type=None):
if param_type is not None and not isinstance(arg, param_type):
raise ValueError(
"The %s function %s type error!" % (func_name, param_name))
# check range
def _check_param_range(arg, param_name):
if isinstance(arg, int) and param_name == "seed" and (
arg < 0 or arg > 2147483647):
raise ValueError(
"The %s function %s exceeds the boundary!" % (
func_name, param_name))
if isinstance(arg, int) and param_name == "count" and ((arg <= 0 and arg != -1) or arg > 2147483647):
raise ValueError(
"The %s function %s exceeds the boundary!" % (
func_name, param_name))
if isinstance(arg, int) and param_name == "prefetch_size" and (
arg <= 0 or arg > 1024):
raise ValueError(
"The %s function %s exceeds the boundary!" % (
func_name, param_name))
if isinstance(arg, int) and param_name == "num_parallel_workers" and (
arg < 1 or arg > cpu_count()):
raise ValueError(
"The %s function %s exceeds the boundary(%s)!" % (
func_name, param_name, cpu_count()))
if isinstance(arg, int) and param_name != "seed" \
and param_name != "count" and param_name != "prefetch_size" \
and param_name != "num_parallel_workers" and (arg < 1 or arg > 2147483647):
raise ValueError(
"The %s function %s exceeds the boundary!" % (
func_name, param_name))
key = _make_key()
# check integer
for karg in req_param_int:
_check_param_type(key[karg], karg, int)
_check_param_range(key[karg], karg)
for karg in nreq_param_int:
if karg in key:
if key[karg] is not None:
_check_param_type(key[karg], karg, int)
_check_param_range(key[karg], karg)
# check bool
for karg in req_param_bool:
_check_param_type(key[karg], karg, bool)
for karg in nreq_param_bool:
if karg in key:
if key[karg] is not None:
_check_param_type(key[karg], karg, bool)
if func_name in '__init__':
if 'columns_list' in key.keys():
columns_list = key['columns_list']
if columns_list is not None:
_check_param_type(columns_list, 'columns_list', list)
if 'columns' in key.keys():
columns = key['columns']
if columns is not None:
_check_param_type(columns, 'columns', list)
if 'partitions' in key.keys():
partitions = key['partitions']
if partitions is not None:
_check_param_type(partitions, 'partitions', list)
if 'schema' in key.keys():
schema = key['schema']
if schema is not None:
check_filename(schema)
if not os.path.isfile(schema) or not os.access(schema, os.R_OK):
raise ValueError(
"The file %s does not exist or permission denied!" % schema)
if 'dataset_dir' in key.keys():
dataset_dir = key['dataset_dir']
if dataset_dir is not None:
if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
raise ValueError(
"The folder %s does not exist or permission denied!" % dataset_dir)
if 'dataset_files' in key.keys():
dataset_files = key['dataset_files']
if not dataset_files:
raise ValueError(
"The dataset file does not exists!")
if dataset_files is not None:
_check_param_type(dataset_files, 'dataset_files', list)
for file in dataset_files:
if not os.path.isfile(file) or not os.access(file, os.R_OK):
raise ValueError(
"The file %s does not exist or permission denied!" % file)
if 'dataset_file' in key.keys():
dataset_file = key['dataset_file']
if not dataset_file:
raise ValueError(
"The dataset file does not exists!")
check_filename(dataset_file)
if dataset_file is not None:
if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
raise ValueError(
"The file %s does not exist or permission denied!" % dataset_file)
return method(*args, **kwargs)
return wrapper
def check_valid_detype(type_):
if type_ not in valid_detype:
raise ValueError("Unknown column type")
......
......@@ -48,7 +48,6 @@ SET(DE_UT_SRCS
shuffle_op_test.cc
stand_alone_samplers_test.cc
status_test.cc
storage_op_test.cc
task_manager_test.cc
tensor_test.cc
tensor_string_test.cc
......
......@@ -54,10 +54,10 @@ std::shared_ptr<de::RepeatOp> Repeat(int repeat_cnt = 1) {
return op;
}
std::shared_ptr<de::StorageOp> Storage(std::string schema, int rows_per_buf = 2, int num_works = 8) {
std::shared_ptr<de::StorageOp> so;
de::StorageOp::Builder builder;
builder.SetDatasetFilesDir(schema).SetRowsPerBuffer(rows_per_buf).SetNumWorkers(num_works);
std::shared_ptr<de::TFReaderOp> TFReader(std::string schema, int rows_per_buf = 2, int num_works = 8) {
std::shared_ptr<de::TFReaderOp> so;
de::TFReaderOp::Builder builder;
builder.SetDatasetFilesList({schema}).SetRowsPerBuffer(rows_per_buf).SetNumWorkers(num_works);
Status rc = builder.Build(&so);
return so;
}
......@@ -77,9 +77,9 @@ std::shared_ptr<de::ExecutionTree> Build(std::vector<std::shared_ptr<de::Dataset
}
TEST_F(MindDataTestBatchOp, TestSimpleBatch) {
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
bool success = false;
auto tree = Build({Storage(schema_file), Batch(12)});
auto tree = Build({TFReader(schema_file), Batch(12)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
......@@ -108,9 +108,9 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) {
}
TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
bool success = false;
auto tree = Build({Storage(schema_file), Repeat(2), Batch(7, true, 99)});
auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, true, 99)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
......@@ -153,9 +153,9 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
}
TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
bool success = false;
auto tree = Build({Storage(schema_file), Repeat(2), Batch(7, false, 99)});
auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, false, 99)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
......@@ -205,9 +205,9 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
}
TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
bool success = false;
auto tree = Build({Storage(schema_file), Batch(7, false, 99), Repeat(2)});
auto tree = Build({TFReader(schema_file), Batch(7, false, 99), Repeat(2)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
......@@ -251,9 +251,9 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
}
TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
bool success = false;
auto tree = Build({Storage(schema_file), Batch(5, true, 99), Repeat(2)});
auto tree = Build({TFReader(schema_file), Batch(5, true, 99), Repeat(2)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
......@@ -297,7 +297,7 @@ TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
}
TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) {
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
std::shared_ptr<BatchOp> op;
PadInfo m;
std::shared_ptr<Tensor> pad_value;
......@@ -305,7 +305,7 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) {
pad_value->SetItemAt<float>({}, -1);
m.insert({"col_1d", std::make_pair(TensorShape({4}), pad_value)});
de::BatchOp::Builder(12).SetDrop(false).SetPaddingMap(m, true).Build(&op);
auto tree = Build({Storage(schema_file), op});
auto tree = Build({TFReader(schema_file), op});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
......
......@@ -88,17 +88,17 @@ TEST_F(MindDataTestClientConfig, TestClientConfig2) {
// Dataset from testDataset1 has 10 rows, 2 columns.
// RowsPerBuffer buffer setting of 2 divides evenly into total rows.
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testDataset1";
std::shared_ptr<StorageOp> my_storage_op;
StorageOp::Builder builder;
builder.SetDatasetFilesDir(dataset_path);
rc = builder.Build(&my_storage_op);
dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder builder;
builder.SetDatasetFilesList({dataset_path});
rc = builder.Build(&my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
ASSERT_EQ(my_storage_op->num_workers(),16);
my_tree->AssociateNode(my_storage_op);
ASSERT_EQ(my_tfreader_op->num_workers(),1);
my_tree->AssociateNode(my_tfreader_op);
// Set children/root layout.
my_tree->AssignRoot(my_storage_op);
my_tree->AssignRoot(my_tfreader_op);
my_tree->Prepare();
my_tree->Launch();
......@@ -116,5 +116,5 @@ TEST_F(MindDataTestClientConfig, TestClientConfig2) {
row_count++;
}
ASSERT_EQ(row_count, 10); // Should be 10 rows fetched
ASSERT_EQ(my_storage_op->num_workers(),16);
ASSERT_EQ(my_tfreader_op->num_workers(),1);
}
......@@ -18,7 +18,7 @@
#include "dataset/core/client.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "dataset/util/de_error.h"
......@@ -103,17 +103,17 @@ TEST_F(MindDataTestExecutionTree, TestExecutionTree2) {
Status rc;
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path = datasets_root_path_ + "/testDataset1";
std::shared_ptr<StorageOp> my_storage_op;
StorageOp::Builder()
.SetDatasetFilesDir(dataset_path)
std::string dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(2)
.SetNumWorkers(2)
.Build(&my_storage_op);
.Build(&my_tfreader_op);
my_tree->AssociateNode(my_storage_op);
my_tree->AssignRoot(my_storage_op);
my_tree->AssociateNode(my_tfreader_op);
my_tree->AssignRoot(my_tfreader_op);
// prepare the tree
my_tree->Prepare();
......
......@@ -91,7 +91,8 @@ class MindDataTestMapOp : public UT::DatasetOpTesting {
public:
void SetUp() override {
DatasetOpTesting::SetUp();
dataset_path_ = datasets_root_path_ + "" + "/testDataset2";
dataset_path_ = datasets_root_path_ + "" + "/testDataset2/testDataset2.data";
schema_path_ = datasets_root_path_ + "" + "/testDataset2/datasetSchema.json";
GlobalInit();
......@@ -99,22 +100,28 @@ class MindDataTestMapOp : public UT::DatasetOpTesting {
my_tree_ = std::make_shared<ExecutionTree>();
}
std::shared_ptr<StorageOp> CreateStorageOp() {
std::shared_ptr<StorageOp> my_storage_op;
StorageOp::Builder builder;
builder.SetDatasetFilesDir(dataset_path_)
std::shared_ptr<TFReaderOp> CreateTFReaderOp() {
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder builder;
builder.SetDatasetFilesList({dataset_path_})
.SetColumnsToLoad({"image", "label", "A", "B"})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(2)
.SetNumWorkers(2);
Status rc = builder.Build(&my_storage_op);
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
schema->LoadSchemaFile(schema_path_, {});
builder.SetDataSchema(std::move(schema));
Status rc = builder.Build(&my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
return my_storage_op;
return my_tfreader_op;
}
std::shared_ptr<ExecutionTree> my_tree_;
private:
std::string dataset_path_;
std::string schema_path_;
};
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
......@@ -124,7 +131,7 @@ std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int6
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
// TestByPosition scenario:
// StorageOp reads a dataset that have column ordering |image|label|A|B|.
// TFReaderOp reads a dataset that have column ordering |image|label|A|B|.
// A TensorOp that does nothing picks the label column and output a column also named label.
// Thus, based on the new MapOp behaviour, the column ordering will be |image|label|A|B|.
// Verify the column ordering based on the Tensor properties matching to that of in the schema file.
......@@ -132,10 +139,10 @@ TEST_F(MindDataTestMapOp, TestByPosition) {
Status rc;
MS_LOG(INFO) << "Doing TestByPosition.";
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total
// of 10 rows.
auto my_storage_op = this->CreateStorageOp();
rc = my_tree_->AssociateNode(my_storage_op);
auto my_tfreader_op = this->CreateTFReaderOp();
rc = my_tree_->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
auto my_no_op = std::make_shared<mindspore::dataset::test::NoOp>();
std::vector<std::shared_ptr<TensorOp>> my_func_list;
......@@ -144,13 +151,14 @@ TEST_F(MindDataTestMapOp, TestByPosition) {
MapOp::Builder builder;
builder.SetInColNames({"label"})
.SetOutColNames({})
.SetColOrder({"image", "label", "A", "B"})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(100);
rc = builder.Build(&my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssociateNode(my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_map_op->AddChild(my_storage_op);
rc = my_map_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssignRoot(my_map_op);
EXPECT_TRUE(rc.IsOk());
......@@ -192,7 +200,7 @@ TEST_F(MindDataTestMapOp, TestByPosition) {
}
// TestAsMap scenario:
// StorageOp reads a dataset that have column ordering |image|label|A|B|.
// TFReaderOp reads a dataset that have column ordering |image|label|A|B|.
// A TensorOp that does nothing picks the "image" column and produces a column named "X".
// Thus, based on the new MapOp behaviour, the column ordering will be |X|label|A|B|.
// Verify that the "image" column is removed and "X" column is added.
......@@ -200,9 +208,9 @@ TEST_F(MindDataTestMapOp, TestAsMap) {
Status rc;
MS_LOG(INFO) << "Doing TestAsMap.";
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total of 10 rows.
auto my_storage_op = this->CreateStorageOp();
rc = my_tree_->AssociateNode(my_storage_op);
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total of 10 rows.
auto my_tfreader_op = this->CreateTFReaderOp();
rc = my_tree_->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
auto my_no_op = std::make_shared<mindspore::dataset::test::NoOp>();
std::vector<std::shared_ptr<TensorOp>> my_func_list;
......@@ -216,7 +224,7 @@ TEST_F(MindDataTestMapOp, TestAsMap) {
rc = builder.Build(&my_map_op);
rc = my_tree_->AssociateNode(my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_map_op->AddChild(my_storage_op);
rc = my_map_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
// Assign the tree root
......@@ -243,7 +251,7 @@ TEST_F(MindDataTestMapOp, TestAsMap) {
}
// Test3to1 scenario:
// StorageOp reads a dataset that have column ordering |image|label|A|B|.
// TFReaderOp reads a dataset that have column ordering |image|label|A|B|.
// A 3-to-1 TensorOp picks the columns [image, A, B] and produce a column named "X".
// Thus, based on the new MapOp behaviour, the column ordering will be |X|label|.
// Verify that the only columns "X" and "label" exist.
......@@ -251,9 +259,9 @@ TEST_F(MindDataTestMapOp, Test3to1) {
Status rc;
MS_LOG(INFO) << "Doing Test3to1.";
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total of 10 rows.
auto my_storage_op = this->CreateStorageOp();
rc = my_tree_->AssociateNode(my_storage_op);
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total of 10 rows.
auto my_tfreader_op = this->CreateTFReaderOp();
rc = my_tree_->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
auto my_op = std::make_shared<mindspore::dataset::test::ThreeToOneOp>();
std::vector<std::shared_ptr<TensorOp>> my_func_list;
......@@ -268,7 +276,7 @@ TEST_F(MindDataTestMapOp, Test3to1) {
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssociateNode(my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_map_op->AddChild(my_storage_op);
rc = my_map_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssignRoot(my_map_op);
EXPECT_TRUE(rc.IsOk());
......@@ -295,7 +303,7 @@ TEST_F(MindDataTestMapOp, Test3to1) {
}
// Test1to3 scenario:
// StorageOp reads a dataset that have column ordering |image|label|A|B|.
// TFReaderOp reads a dataset that have column ordering |image|label|A|B|.
// A 1-to-3 TensorOp picks the columns [image] and produce a column named [X, Y, Z].
// Thus, based on the new MapOp behaviour, the column ordering will be |X|Y|Z|label|A|B|.
// Verify that the only columns X, Y, Z are added (to the front) and followed by columns label, A, B..
......@@ -303,9 +311,9 @@ TEST_F(MindDataTestMapOp, Test1to3) {
Status rc;
MS_LOG(INFO) << "Doing Test1to3.";
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total of 10 rows.
auto my_storage_op = this->CreateStorageOp();
rc = my_tree_->AssociateNode(my_storage_op);
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total of 10 rows.
auto my_tfreader_op = this->CreateTFReaderOp();
rc = my_tree_->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
auto my_op = std::make_shared<mindspore::dataset::test::OneToThreeOp>();
std::vector<std::shared_ptr<TensorOp>> my_func_list;
......@@ -316,12 +324,25 @@ TEST_F(MindDataTestMapOp, Test1to3) {
.SetOutColNames({"X", "Y", "Z"})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(1);
// ProjectOp
std::vector<std::string> columns_to_project = {"X", "Y", "Z", "label", "A", "B"};
std::shared_ptr<ProjectOp> my_project_op = std::make_shared<ProjectOp>(columns_to_project);
rc = my_tree_->AssociateNode(my_project_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree_->AssignRoot(my_project_op);
ASSERT_TRUE(rc.IsOk());
rc = builder.Build(&my_map_op);
rc = my_tree_->AssociateNode(my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_map_op->AddChild(my_storage_op);
rc = my_project_op->AddChild(my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssignRoot(my_map_op);
rc = my_map_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Prepare();
EXPECT_TRUE(rc.IsOk());
......@@ -371,7 +392,7 @@ TEST_F(MindDataTestMapOp, Test1to3) {
}
// TestMultiTensorOp scenario:
// StorageOp reads a dataset that have column ordering |image|label|A|B|.
// TFReaderOp reads a dataset that have column ordering |image|label|A|B|.
// A series of 3-to-1 and 1-to-3 TensorOps are applied to [image, A, B] and
// produce final output columns [X, Y, Z].
// Based on the new MapOp behaviour, the column ordering will be |X|Y|Z|label|.
......@@ -379,9 +400,9 @@ TEST_F(MindDataTestMapOp, TestMultiTensorOp) {
Status rc;
MS_LOG(INFO) << "Doing TestMultiTensorOp.";
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total of 10 rows.
auto my_storage_op = this->CreateStorageOp();
rc = my_tree_->AssociateNode(my_storage_op);
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total of 10 rows.
auto my_tfreader_op = this->CreateTFReaderOp();
rc = my_tree_->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
auto my_op1 = std::make_shared<mindspore::dataset::test::ThreeToOneOp>();
auto my_op2 = std::make_shared<mindspore::dataset::test::OneToThreeOp>();
......@@ -398,7 +419,7 @@ TEST_F(MindDataTestMapOp, TestMultiTensorOp) {
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssociateNode(my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_map_op->AddChild(my_storage_op);
rc = my_map_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssignRoot(my_map_op);
EXPECT_TRUE(rc.IsOk());
......@@ -431,15 +452,15 @@ TEST_F(MindDataTestMapOp, TestMultiTensorOp) {
}
}
TEST_F(MindDataTestMapOp, TestStorageRepeatMap) {
TEST_F(MindDataTestMapOp, TestTFReaderRepeatMap) {
Status rc;
MS_LOG(INFO) << "Doing TestStorageRepeatMap.";
MS_LOG(INFO) << "Doing TestTFReaderRepeatMap.";
uint32_t num_repeats = 3;
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total
// of 10 rows.
auto my_storage_op = this->CreateStorageOp();
rc = my_tree_->AssociateNode(my_storage_op);
auto my_tfreader_op = this->CreateTFReaderOp();
rc = my_tree_->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
auto my_no_op = std::make_shared<mindspore::dataset::test::NoOp>();
std::vector<std::shared_ptr<TensorOp>> my_func_list;
......@@ -465,7 +486,7 @@ TEST_F(MindDataTestMapOp, TestStorageRepeatMap) {
rc = my_map_op->AddChild(my_repeat_op);
EXPECT_TRUE(rc.IsOk());
rc = my_repeat_op->AddChild(my_storage_op);
rc = my_repeat_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssignRoot(my_map_op);
......@@ -493,15 +514,15 @@ TEST_F(MindDataTestMapOp, TestStorageRepeatMap) {
ASSERT_EQ(row_count, 10 * num_repeats);
}
TEST_F(MindDataTestMapOp, TestStorageMapRepeat) {
TEST_F(MindDataTestMapOp, TestTFReaderMapRepeat) {
Status rc;
MS_LOG(INFO) << "Doing TestStorageMapRepeat.";
MS_LOG(INFO) << "Doing TestTFReaderMapRepeat.";
uint32_t num_repeats = 3;
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total
// of 10 rows.
auto my_storage_op = this->CreateStorageOp();
rc = my_tree_->AssociateNode(my_storage_op);
auto my_tfreader_op = this->CreateTFReaderOp();
rc = my_tree_->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
auto my_no_op = std::make_shared<mindspore::dataset::test::NoOp>();
std::vector<std::shared_ptr<TensorOp>> my_func_list;
......@@ -527,7 +548,7 @@ TEST_F(MindDataTestMapOp, TestStorageMapRepeat) {
rc = my_repeat_op->AddChild(my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_map_op->AddChild(my_storage_op);
rc = my_map_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssignRoot(my_repeat_op);
......@@ -554,23 +575,23 @@ TEST_F(MindDataTestMapOp, TestStorageMapRepeat) {
ASSERT_EQ(row_count, 10 * num_repeats);
}
TEST_F(MindDataTestMapOp, Storage_Decode_Repeat_Resize) {
TEST_F(MindDataTestMapOp, TFReader_Decode_Repeat_Resize) {
Status rc;
MS_LOG(INFO) << "Doing Storage_Decode_Repeat_Resize.";
MS_LOG(INFO) << "Doing TFReader_Decode_Repeat_Resize.";
uint32_t num_repeats = 2;
std::string dataset_path_ = datasets_root_path_ + "/" + "test_tf_file_3_images";
std::shared_ptr<StorageOp> my_storage_op;
StorageOp::Builder sobuilder;
sobuilder.SetDatasetFilesDir(dataset_path_)
std::string dataset_path_ = datasets_root_path_ + "/" + "test_tf_file_3_images/train-0000-of-0001.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder sobuilder;
sobuilder.SetDatasetFilesList({dataset_path_})
.SetColumnsToLoad({"image", "label"})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(2)
.SetNumWorkers(2);
rc = sobuilder.Build(&my_storage_op);
rc = sobuilder.Build(&my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssociateNode(my_storage_op);
rc = my_tree_->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
auto decode_op = std::make_shared<DecodeOp>();
std::vector<std::shared_ptr<TensorOp>> my_func_list;
......@@ -608,7 +629,7 @@ TEST_F(MindDataTestMapOp, Storage_Decode_Repeat_Resize) {
rc = my_tree_->AssociateNode(my_map_resize_op);
EXPECT_TRUE(rc.IsOk());
rc = my_map_decode_op->AddChild(my_storage_op);
rc = my_map_decode_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_repeat_op->AddChild(my_map_decode_op);
......
......@@ -44,23 +44,23 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) {
//
// OpId(2) RenameOp
// |
// OpId(0) StorageOp
// OpId(0) TFReaderOp
// Start with an empty execution tree
Status rc;
MS_LOG(INFO) << "UT test TestRenameBasic.";
auto my_tree = std::make_shared<ExecutionTree>();
// Creating StorageOp
// Creating TFReaderOp
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1";
std::shared_ptr<StorageOp> my_storage_op;
rc = StorageOp::Builder()
.SetDatasetFilesDir(dataset_path)
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(16)
.SetNumWorkers(1)
.Build(&my_storage_op);
.Build(&my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_storage_op);
rc = my_tree->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
// Creating DatasetOp
......@@ -76,7 +76,7 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) {
rc = my_tree->AssociateNode(rename_op);
EXPECT_TRUE(rc.IsOk());
rc = rename_op->AddChild(std::move(my_storage_op));
rc = rename_op->AddChild(std::move(my_tfreader_op));
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(rename_op);
EXPECT_TRUE(rc.IsOk());
......
......@@ -39,11 +39,11 @@ class MindDataTestShuffleOp : public UT::DatasetOpTesting {
// - RowsPerBuffer buffer setting of 2 divides evenly into total rows.
// - Shuffle size is multiple of rows per buffer.
//
// Tree: shuffle over storage
// Tree: shuffle over TFReader
//
// ShuffleOp
// |
// StorageOp
// TFReaderOp
//
TEST_F(MindDataTestShuffleOp, TestShuffleBasic1) {
Status rc;
......@@ -53,16 +53,16 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic1) {
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testDataset1";
std::shared_ptr<StorageOp> my_storage_op;
rc = StorageOp::Builder()
.SetDatasetFilesDir(dataset_path)
dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(16)
.SetNumWorkers(1)
.Build(&my_storage_op);
.Build(&my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_storage_op);
rc = my_tree->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<ShuffleOp> my_shuffle_op;
rc = ShuffleOp::Builder().SetRowsPerBuffer(2).SetShuffleSize(4).Build(&my_shuffle_op);
......@@ -71,7 +71,7 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic1) {
EXPECT_TRUE(rc.IsOk());
// Set children/root layout.
rc = my_shuffle_op->AddChild(my_storage_op);
rc = my_shuffle_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_shuffle_op);
EXPECT_TRUE(rc.IsOk());
......@@ -112,11 +112,11 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic1) {
// - Shuffle size is not a multiple of rows per buffer.
// - User has provided a non-default seed value.
//
// Tree: shuffle over storage
// Tree: shuffle over TFReader
//
// ShuffleOp
// |
// StorageOp
// TFReaderOp
//
TEST_F(MindDataTestShuffleOp, TestShuffleBasic2) {
Status rc;
......@@ -126,16 +126,16 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic2) {
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testDataset1";
std::shared_ptr<StorageOp> my_storage_op;
rc = StorageOp::Builder()
.SetDatasetFilesDir(dataset_path)
dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(3)
.SetWorkerConnectorSize(16)
.SetNumWorkers(2)
.Build(&my_storage_op);
.Build(&my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_storage_op);
rc = my_tree->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<ShuffleOp> my_shuffle_op;
rc = ShuffleOp::Builder().SetShuffleSize(4).SetShuffleSeed(100).SetRowsPerBuffer(3).Build(&my_shuffle_op);
......@@ -144,7 +144,7 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic2) {
EXPECT_TRUE(rc.IsOk());
// Set children/root layout.
rc = my_shuffle_op->AddChild(my_storage_op);
rc = my_shuffle_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_shuffle_op);
EXPECT_TRUE(rc.IsOk());
......@@ -183,11 +183,11 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic2) {
// - Shuffle size captures the entire dataset size (actually sets a value that is larger than the
// amount of rows in the dataset.
//
// Tree: shuffle over storage
// Tree: shuffle over TFReader
//
// ShuffleOp
// |
// StorageOp
// TFReaderOp
//
TEST_F(MindDataTestShuffleOp, TestShuffleBasic3) {
Status rc;
......@@ -197,16 +197,16 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic3) {
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testDataset1";
std::shared_ptr<StorageOp> my_storage_op;
rc = StorageOp::Builder()
.SetDatasetFilesDir(dataset_path)
dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(3)
.SetWorkerConnectorSize(16)
.SetNumWorkers(2)
.Build(&my_storage_op);
.Build(&my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
my_tree->AssociateNode(my_storage_op);
my_tree->AssociateNode(my_tfreader_op);
std::shared_ptr<ShuffleOp> my_shuffle_op;
rc = ShuffleOp::Builder().SetShuffleSize(100).SetRowsPerBuffer(3).Build(&my_shuffle_op);
EXPECT_TRUE(rc.IsOk());
......@@ -214,7 +214,7 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic3) {
EXPECT_TRUE(rc.IsOk());
// Set children/root layout.
rc = my_shuffle_op->AddChild(my_storage_op);
rc = my_shuffle_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_shuffle_op);
EXPECT_TRUE(rc.IsOk());
......@@ -255,13 +255,13 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic3) {
// - shuffle seed is given, and subsequent epochs will change the seed each time.
// - Repeat count of 2
//
// Tree: Repeat over shuffle over storage
// Tree: Repeat over shuffle over TFReader
//
// Repeat
// |
// shuffle
// |
// StorageOp
// TFReaderOp
//
TEST_F(MindDataTestShuffleOp, TestRepeatShuffle) {
Status rc;
......@@ -271,16 +271,16 @@ TEST_F(MindDataTestShuffleOp, TestRepeatShuffle) {
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testDataset1";
std::shared_ptr<StorageOp> my_storage_op;
rc = StorageOp::Builder()
.SetDatasetFilesDir(dataset_path)
dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(3)
.SetWorkerConnectorSize(16)
.SetNumWorkers(2)
.Build(&my_storage_op);
.Build(&my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_storage_op);
rc = my_tree->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<ShuffleOp> my_shuffle_op;
rc = ShuffleOp::Builder()
......@@ -302,7 +302,7 @@ TEST_F(MindDataTestShuffleOp, TestRepeatShuffle) {
// Set children/root layout.
rc = my_repeat_op->AddChild(my_shuffle_op);
EXPECT_TRUE(rc.IsOk());
rc = my_shuffle_op->AddChild(my_storage_op);
rc = my_shuffle_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_repeat_op);
EXPECT_TRUE(rc.IsOk());
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/core/client.h"
#include "common/common.h"
#include "common/utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include <memory>
#include <vector>
#include <iostream>
namespace common = mindspore::common;
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestStorageOp : public UT::DatasetOpTesting {
};
TEST_F(MindDataTestStorageOp, TestStorageBasic1) {
// single storage op and nothing else
//
// StorageOp
MS_LOG(INFO) << "UT test TestStorageBasic1.";
Status rc;
// Start with an empty execution tree
auto my_tree = std::make_shared<ExecutionTree>();
// Test info:
// Dataset from testDataset1 has 10 rows, 2 columns.
// RowsPerBuffer buffer setting of 2 divides evenly into total rows.
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testDataset1";
std::shared_ptr<StorageOp> my_storage_op;
StorageOp::Builder builder;
builder.SetDatasetFilesDir(dataset_path)
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(16)
.SetNumWorkers(1);
rc = builder.Build(&my_storage_op);
ASSERT_TRUE(rc.IsOk());
my_tree->AssociateNode(my_storage_op);
// Set children/root layout.
my_tree->AssignRoot(my_storage_op);
MS_LOG(INFO) << "Launching tree and begin iteration.";
my_tree->Prepare();
my_tree->Launch();
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
MS_LOG(INFO) << "Row display for row #: " << row_count << ".";
// Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << common::SafeCStr(ss.str()) << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 10); // Should be 10 rows fetched
// debugging temp. what happens if we keep fetching..
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestStorageOp, TestStorageBasic2) {
// single storage op and nothing else
//
// StorageOp
MS_LOG(INFO) << "UT test TestStorageBasic1.";
Status rc;
// Start with an empty execution tree
auto my_tree = std::make_shared<ExecutionTree>();
// Test info:
// Dataset from testDataset1 has 10 rows, 2 columns.
// RowsPerBuffer buffer setting of 3 yields 4 buffers with the last buffer having single row
// only. 2 workers.
// Test a column selection instead of all columns as well.
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testDataset1";
std::vector<std::string> column_list;
std::string label_colname("label");
column_list.push_back(label_colname);
std::shared_ptr<StorageOp> my_storage_op;
StorageOp::Builder builder;
builder.SetDatasetFilesDir(dataset_path)
.SetRowsPerBuffer(3)
.SetWorkerConnectorSize(16)
.SetNumWorkers(2)
.SetColumnsToLoad(column_list);
rc = builder.Build(&my_storage_op);
ASSERT_TRUE(rc.IsOk());
my_tree->AssociateNode(my_storage_op);
// Set children/root layout.
my_tree->AssignRoot(my_storage_op);
MS_LOG(INFO) << "Launching tree and begin iteration.";
my_tree->Prepare();
my_tree->Launch();
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
MS_LOG(INFO) << "Row display for row #: " << row_count << ".";
// Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << common::SafeCStr(ss.str()) << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 10); // Should be 10 rows fetched
}
......@@ -51,35 +51,35 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
*
* OpId(2) ZipOp
* / \
* OpId(0) StorageOp OpId(1) StorageOp
* OpId(0) TFReaderOp OpId(1) TFReaderOp
* Start with an empty execution tree
*/
Status rc;
MS_LOG(INFO) << "UT test TestZipBasic.";
auto my_tree = std::make_shared<ExecutionTree>();
// Creating StorageOp
// Creating TFReaderOp
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1";
std::string dataset_path2 = datasets_root_path_ + "/test_tf_file_3_images_2";
std::shared_ptr<StorageOp> my_storage_op;
rc = StorageOp::Builder()
.SetDatasetFilesDir(dataset_path)
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(16)
.SetNumWorkers(1)
.Build(&my_storage_op);
.Build(&my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_storage_op);
rc = my_tree->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<StorageOp> my_storage_op2;
rc = StorageOp::Builder()
.SetDatasetFilesDir(dataset_path2)
std::shared_ptr<TFReaderOp> my_tfreader_op2;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path2})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(1)
.SetNumWorkers(1)
.Build(&my_storage_op2);
.Build(&my_tfreader_op2);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_storage_op2);
rc = my_tree->AssociateNode(my_tfreader_op2);
EXPECT_TRUE(rc.IsOk());
// Creating DatasetOp
......@@ -89,9 +89,9 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
rc = my_tree->AssociateNode(zip_op);
EXPECT_TRUE(rc.IsOk());
rc = zip_op->AddChild(std::move(my_storage_op));
rc = zip_op->AddChild(std::move(my_tfreader_op));
EXPECT_TRUE(rc.IsOk());
rc = zip_op->AddChild(std::move(my_storage_op2));
rc = zip_op->AddChild(std::move(my_tfreader_op2));
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(zip_op);
EXPECT_TRUE(rc.IsOk());
......@@ -125,6 +125,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
EXPECT_TRUE(rc.IsOk());
row_count++;
}
MS_LOG(WARNING) <<"row count is: " << row_count;
ASSERT_EQ(row_count, 3); // Should be 3 rows fetched
}
......@@ -135,7 +136,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
*
* OpId(2) ZipOp
* / \
* OpId(0) StorageOp OpId(1) StorageOp
* OpId(0) TFReaderOp OpId(1) TFReaderOp
*
* Start with an empty execution tree
*/
......@@ -143,27 +144,27 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
MS_LOG(INFO) << "UT test TestZipRepeat.";
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1";
std::string dataset_path2 = datasets_root_path_ + "/test_tf_file_3_images_2";
std::shared_ptr<StorageOp> my_storage_op;
rc = StorageOp::Builder()
.SetDatasetFilesDir(dataset_path)
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(16)
.SetNumWorkers(1)
.Build(&my_storage_op);
.Build(&my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_storage_op);
rc = my_tree->AssociateNode(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
std::shared_ptr<StorageOp> my_storage_op2;
rc = StorageOp::Builder()
.SetDatasetFilesDir(dataset_path2)
std::shared_ptr<TFReaderOp> my_tfreader_op2;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path2})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(1)
.SetNumWorkers(1)
.Build(&my_storage_op2);
.Build(&my_tfreader_op2);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_storage_op2);
rc = my_tree->AssociateNode(my_tfreader_op2);
EXPECT_TRUE(rc.IsOk());
// Creating DatasetOp
std::shared_ptr<ZipOp> zip_op;
......@@ -171,9 +172,9 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(zip_op);
EXPECT_TRUE(rc.IsOk());
rc = zip_op->AddChild(std::move(my_storage_op));
rc = zip_op->AddChild(std::move(my_tfreader_op));
EXPECT_TRUE(rc.IsOk());
rc = zip_op->AddChild(std::move(my_storage_op2));
rc = zip_op->AddChild(std::move(my_tfreader_op2));
EXPECT_TRUE(rc.IsOk());
// Builder(num_of_repeats)
......
{
"deviceNum":3,
"deviceId":1,
"shardConfig":"ALL",
"shuffle":"ON",
"seed": 0,
"epoch": 2
}
{
"deviceNum":7,
"deviceId":6,
"shardConfig":"RANDOM",
"shuffle":"ON",
"seed": 0,
"epoch": 1
}
{
"deviceNum":3,
"deviceId":1,
"shardConfig":"RANDOM",
"shuffle":"ON",
"seed": 0,
"epoch": 1
}
{
"deviceNum":3,
"deviceId":1,
"shardConfig":"UNIQUE",
"shuffle":"ON",
"seed": 0,
"epoch": 3
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册