Commit 526770e0 authored by mindspore-ci-bot and committed by Gitee

!3049 [dataset] add save operator in dataset

Merge pull request !3049 from liyong126/dataset_save_op
......@@ -42,11 +42,17 @@
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_category.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_header.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "minddata/mindrecord/include/shard_writer.h"
#include "pybind11/stl.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
using json = nlohmann::json;
using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *, std::shared_ptr<DatasetOp> *);
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
......@@ -355,6 +361,226 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
return Status::OK();
}
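// Save the rows produced by the current pipeline into mindrecord files: fetch rows from the
// iterator, build the mindrecord schema and index fields from the first row, merge each row's
// bytes (blob) columns and write raw/blob data through ShardWriter, then commit the files and
// generate the index database.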
Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type) {
Status s;
auto mr_header = std::make_shared<mindrecord::ShardHeader>();
auto mr_writer = std::make_unique<mindrecord::ShardWriter>();
std::vector<std::string> blob_fields;
uint64_t mr_schema_id = 0;
if (mindrecord::SUCCESS != mindrecord::ShardWriter::initialize(&mr_writer, file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter.");
}
TensorRow row;
std::unordered_map<std::string, int32_t> column_name_id_map =
iterator_->GetColumnNameMap(); // map of column name, id
bool first_loop = true; // build schema in first loop
do {
json row_raw_data;
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
{
py::gil_scoped_release gil_release;
s = iterator_->FetchNextTensorRow(&row);
}
RETURN_IF_NOT_OK(s);
if (row.empty()) break;
if (first_loop) {
json mr_json;
std::vector<std::string> index_fields;
s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
RETURN_IF_NOT_OK(s);
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
mr_writer->SetShardHeader(mr_header);
first_loop = false;
}
// construct data
if (!row.empty()) { // write data
s = FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data);
RETURN_IF_NOT_OK(s);
std::shared_ptr<std::vector<uint8_t>> output_bin_data;
mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data);
std::map<std::uint64_t, std::vector<json>> raw_data;
raw_data.insert(std::pair<uint64_t, std::vector<json>>(mr_schema_id, std::vector<json>{row_raw_data}));
std::vector<std::vector<uint8_t>> bin_data;
if (nullptr != output_bin_data) {
bin_data.emplace_back(*output_bin_data);
}
mr_writer->WriteRawData(raw_data, bin_data);
}
} while (!row.empty());
mr_writer->Commit();
mindrecord::ShardIndexGenerator::finalize(file_names);
return Status::OK();
}
Status DEPipeline::FetchDataFromTensorRow(const TensorRow &row,
const std::unordered_map<std::string, int32_t> &column_name_id_map,
json *row_raw_data,
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data) {
if (row_raw_data == nullptr) {
RETURN_STATUS_UNEXPECTED("error: row raw data is NULL.");
}
if (row_bin_data == nullptr) {
RETURN_STATUS_UNEXPECTED("error: row bin data is NULL.");
}
if (column_name_id_map.empty()) {
RETURN_STATUS_UNEXPECTED("Error: column not found");
}
Status s;
for (auto &col : column_name_id_map) {
auto idx = col.second;
auto column_name = col.first;
auto &tensor = row[idx];
auto column_type = tensor->type();
std::unique_ptr<std::vector<uint8_t>> data_ptr;
if (column_type == DataType::DE_INT8) {
std::unique_ptr<int32_t> data;
std::unique_ptr<int8_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT16) {
std::unique_ptr<int32_t> data;
std::unique_ptr<int16_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT16) {
std::unique_ptr<int32_t> data;
std::unique_ptr<uint16_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT8) {
std::unique_ptr<uint8_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT32) {
std::unique_ptr<int32_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT32) {
std::unique_ptr<int64_t> data;
std::unique_ptr<uint32_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT64) {
std::unique_ptr<int64_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_FLOAT32) {
std::unique_ptr<float> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_FLOAT64) {
std::unique_ptr<double> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_STRING) {
auto buffer = tensor->GetStringsBuffer();
std::string ss(reinterpret_cast<const char *>(buffer)); // assume scalar string tensor
(*row_raw_data)[column_name] = std::move(ss);
continue;
} else {
RETURN_STATUS_UNEXPECTED("Got unexpected type when casting data.");
}
RETURN_IF_NOT_OK(s);
if (data_ptr != nullptr) {
(*row_bin_data)[column_name] = std::move(data_ptr);
}
}
return Status::OK();
}
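// Copy a tensor buffer into a byte vector for the blob area. When need_convert is true the source
// elements of type S are widened one by one to type T (e.g. int8/int16/uint16 -> int32); otherwise
// the bytes are copied as-is. For a scalar tensor (empty shape) the first value is additionally
// copied into *data so the caller can store it as raw (json) data.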
template <typename T, typename S>
Status DEPipeline::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert) {
if (nullptr == src) {
RETURN_STATUS_UNEXPECTED("Error: buffer of Tensor is NULL.");
}
*data_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(T));
if (need_convert) {
auto tmp_ptr = std::make_unique<std::vector<uint8_t>>(num_of_elements * sizeof(S));
std::copy(src, src + sizeof(S) * num_of_elements, tmp_ptr->begin());
auto s_ptr = reinterpret_cast<S *>(&(*(tmp_ptr->begin())));
auto el = std::make_unique<T>();
for (uint32_t i = 0; i < num_of_elements; ++i) {
*el = *(s_ptr + i);
auto t_ptr = reinterpret_cast<uint8_t *>(el.get());
for (uint32_t j = 0; j < sizeof(T); ++j) {
*((*data_ptr)->begin() + i * sizeof(T) + j) = *(t_ptr + j);
}
}
} else {
std::copy(src, src + sizeof(T) * num_of_elements, (*data_ptr)->begin());
}
if (shape.empty()) {
*data = std::make_unique<T>();
auto t_ptr = reinterpret_cast<uint8_t *>((*data).get());
for (uint32_t i = 0; i < sizeof(T); ++i) {
*(t_ptr + i) = *((*data_ptr)->begin() + i);
}
}
return Status::OK();
}
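// Derive the mindrecord schema (json) from the tensor types and shapes of the first row. Column
// types are mapped through mindrecord::kTypesMap; scalar, non-bytes columns are also collected as
// candidate index fields.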
Status DEPipeline::FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
const TensorRow &row, json *schema, std::vector<std::string> *index_fields) {
if (schema == nullptr) {
RETURN_STATUS_UNEXPECTED("error: schema is NULL.");
}
if (index_fields == nullptr) {
RETURN_STATUS_UNEXPECTED("error: index fields is NULL.");
}
if (column_name_id_map.empty()) {
RETURN_STATUS_UNEXPECTED("Error: column not found.");
}
for (auto &col : column_name_id_map) {
auto idx = col.second;
auto column_name = col.first;
auto &tensor = row[idx];
auto column_type = tensor->type();
auto column_shape = tensor->shape();
std::string mr_type;
auto shapes = column_shape.AsVector();
std::vector<int> mr_shape(shapes.begin(), shapes.end());
std::string el = column_type.ToString();
if (mindrecord::kTypesMap.find(el) == mindrecord::kTypesMap.end()) {
std::string err_msg("Error: can not support data type: " + el);
RETURN_STATUS_UNEXPECTED(err_msg);
} else {
mr_type = mindrecord::kTypesMap.at(el);
}
if (mr_shape.empty()) {
if (mr_type == "bytes") { // map to int32 when bytes without shape.
mr_type == "int32";
}
(*schema)[column_name] = {{"type", mr_type}};
} else {
if (mr_type == "string") { // mindrecord can not support string with shape.
std::string err_msg("Error: mindrecord can not support multi-dimensional string tensor.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (mr_type == "bytes") { // ignore shape of bytes in minrecord
(*schema)[column_name] = {{"type", mr_type}};
} else {
(*schema)[column_name] = {{"type", mr_type}, {"shape", mr_shape}};
}
}
if (mr_type == "bytes" || !mr_shape.empty()) continue;
index_fields->emplace_back(column_name); // candidate of index fields
}
return Status::OK();
}
Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded) {
......
......@@ -17,6 +17,7 @@
#define DATASET_API_DE_PIPELINE_H_
#include <iostream>
#include <map>
#include <memory>
#include <stack>
#include <string>
......@@ -33,6 +34,7 @@
namespace py = pybind11;
namespace mindspore {
namespace dataset {
using json = nlohmann::json;
using DsOpPtr = std::shared_ptr<DatasetOp>;
class CacheClient;
......@@ -100,6 +102,8 @@ class DEPipeline {
Status GetOutputTypes(py::list *output);
Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type);
int GetDatasetSize() const;
int GetBatchSize() const;
......@@ -110,6 +114,18 @@ class DEPipeline {
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
template <typename T, typename S>
Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert = false);
Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
const TensorRow &row, json *schema, std::vector<std::string> *index_fields);
Status FetchDataFromTensorRow(const TensorRow &row,
const std::unordered_map<std::string, int32_t> &column_name_id_map, json *row_raw_data,
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data);
Status BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded);
......
......@@ -184,7 +184,11 @@ void bindDEPipeline(py::module *m) {
.def("GetDatasetSize", &DEPipeline::GetDatasetSize)
.def("GetBatchSize", &DEPipeline::GetBatchSize)
.def("GetNumClasses", &DEPipeline::GetNumClasses)
.def("GetRepeatCount", &DEPipeline::GetRepeatCount);
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
return true;
});
}
void bindDatasetOps(py::module *m) {
(void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
......
......@@ -312,6 +312,11 @@ class Tensor {
// @return const unsigned char*
const unsigned char *GetBuffer() const;
// Skip the offsets and return the start of the buffer where the real strings are stored. The caller needs to check
// that the tensor's type is string, otherwise an undefined address would be returned.
// @return address of the first string of the tensor.
uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; }
// Getter of the type
// @return
DataType type() const { return type_; }
......@@ -643,11 +648,6 @@ class Tensor {
// @return length of the string
Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const;
// Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the
// tensor's type is a string, otherwise undefined address would be returned.
// @return address of the first string of the tensor.
uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; }
// all access to shape_ should be via shape
TensorShape shape_;
// data type of tensor
......
......@@ -215,7 +215,7 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\n Dataset file : ";
out << "\nDataset file : ";
for (auto &file : dataset_file_) {
out << file << " ";
}
......
......@@ -137,6 +137,10 @@ const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "
// number field list
const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};
const std::unordered_map<std::string, std::string> kTypesMap = {
{"bool", "int32"}, {"int8", "int32"}, {"uint8", "bytes"}, {"int16", "int32"},
{"uint16", "int32"}, {"int32", "int32"}, {"uint32", "int64"}, {"int64", "int64"},
{"float16", "float32"}, {"float32", "float32"}, {"float64", "float64"}, {"string", "string"}};
/// \brief split a string using a character
/// \param[in] field target string
/// \param[in] separator a character for splitting
......
......@@ -124,6 +124,10 @@ class ShardHeader {
MSRStatus FileToPages(const std::string dump_file_name);
static MSRStatus initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id);
private:
MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);
......
......@@ -57,6 +57,8 @@ class ShardIndexGenerator {
/// \brief create databases for indexes
MSRStatus WriteToDatabase();
static MSRStatus finalize(const std::vector<std::string> file_names);
private:
static int Callback(void *not_used, int argc, char **argv, char **az_col_name);
......
......@@ -108,6 +108,13 @@ class ShardWriter {
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
bool parallel_writer = false);
MSRStatus MergeBlobData(const std::vector<string> &blob_fields,
const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
std::shared_ptr<std::vector<uint8_t>> *output);
static MSRStatus initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
const std::vector<std::string> &file_names);
private:
/// \brief write shard header data to disk
MSRStatus WriteShardHeader();
......
......@@ -622,5 +622,21 @@ void ShardIndexGenerator::DatabaseWriter() {
shard_no = task_++;
}
}
MSRStatus ShardIndexGenerator::finalize(const std::vector<std::string> file_names) {
if (file_names.empty()) {
MS_LOG(ERROR) << "Mindrecord files is empty.";
return FAILED;
}
ShardIndexGenerator sg{file_names[0]};
if (SUCCESS != sg.Build()) {
MS_LOG(ERROR) << "Failed to build index generator.";
return FAILED;
}
if (SUCCESS != sg.WriteToDatabase()) {
MS_LOG(ERROR) << "Failed to write to database.";
return FAILED;
}
return SUCCESS;
}
} // namespace mindrecord
} // namespace mindspore
......@@ -637,6 +637,42 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
*row_count = std::get<2>(v);
return SUCCESS;
}
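// Concatenate the bytes (blob) columns of one row into a single buffer. A single blob field is
// copied directly; with several fields each value is stored as an 8-byte big-endian length prefix
// followed by the raw bytes.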
MSRStatus ShardWriter::MergeBlobData(const std::vector<string> &blob_fields,
const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
std::shared_ptr<std::vector<uint8_t>> *output) {
if (blob_fields.empty()) {
return SUCCESS;
}
if (blob_fields.size() == 1) {
auto &blob = row_bin_data.at(blob_fields[0]);
auto blob_size = blob->size();
*output = std::make_shared<std::vector<uint8_t>>(blob_size);
std::copy(blob->begin(), blob->end(), (*output)->begin());
} else {
size_t output_size = 0;
for (auto &field : blob_fields) {
output_size += row_bin_data.at(field)->size();
}
output_size += blob_fields.size() * sizeof(uint64_t);
*output = std::make_shared<std::vector<uint8_t>>(output_size);
std::vector<uint8_t> buf(sizeof(uint64_t), 0);
size_t idx = 0;
for (auto &field : blob_fields) {
auto &blob = row_bin_data.at(field);
uint64_t blob_size = blob->size();
// big endian: most significant byte first
for (size_t i = 0; i < buf.size(); ++i) {
buf[buf.size() - 1 - i] = std::numeric_limits<uint8_t>::max() & blob_size;
blob_size >>= 8u;
}
std::copy(buf.begin(), buf.end(), (*output)->begin() + idx);
idx += buf.size();
std::copy(blob->begin(), blob->end(), (*output)->begin() + idx);
idx += blob->size();
}
}
return SUCCESS;
}
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
......@@ -1250,5 +1286,21 @@ void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &la
last_blob_page = page.first;
}
}
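// Open the target mindrecord files for writing and apply default sizes: 16 MB of header space
// (1 << 24) and a 32 MB page size (1 << 25).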
MSRStatus ShardWriter::initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
const std::vector<std::string> &file_names) {
if (nullptr == writer_ptr) {
MS_LOG(ERROR) << "ShardWriter pointer is NULL.";
return FAILED;
}
auto res = (*writer_ptr)->Open(file_names, false);
if (SUCCESS != res) {
MS_LOG(ERROR) << "Failed to open mindrecord files to writer.";
return FAILED;
}
(*writer_ptr)->SetHeaderSize(1 << 24);
(*writer_ptr)->SetPageSize(1 << 25);
return SUCCESS;
}
} // namespace mindrecord
} // namespace mindspore
......@@ -721,5 +721,35 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
page_in_handle.close();
return SUCCESS;
}
MSRStatus ShardHeader::initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id) {
if (nullptr == header_ptr) {
MS_LOG(ERROR) << "ShardHeader pointer is NULL.";
return FAILED;
}
auto schema_ptr = Schema::Build("mindrecord", schema);
if (nullptr == schema_ptr) {
MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema.";
return FAILED;
}
schema_id = (*header_ptr)->AddSchema(schema_ptr);
// create index
std::vector<std::pair<uint64_t, std::string>> id_index_fields;
if (!index_fields.empty()) {
for (auto &el : index_fields) {
id_index_fields.emplace_back(schema_id, el);
}
if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) {
MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index.";
return FAILED;
}
}
auto build_schema_ptr = (*header_ptr)->GetSchemas()[0];
blob_fields = build_schema_ptr->GetBlobFields();
return SUCCESS;
}
} // namespace mindrecord
} // namespace mindspore
......@@ -38,13 +38,13 @@ from mindspore._c_expression import typing
from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
......@@ -1044,6 +1044,34 @@ class Dataset:
return TransferDataset(self, queue_name, device_id, device_type, num_batch)
@check_save
def save(self, file_name, num_files=1, file_type='mindrecord'):
"""
Save the dynamic data processed by dataset pipeline as common dataset format, support: mindrecord.
Note:
1. To save the samples in order, should set dataset's shuffle false and num_files 1.
2. Before call the function, do not use batch, repeat operator or data augmentation operators
with random attribute in map operator.
3. Mindreocrd do not support np.uint64, multi-dimensional np.uint8(drop dimension) and
multi-dimensional string.
Args:
file_name (str): Path to dataset file.
num_files (int, optional): Number of dataset files.(default=1).
file_type (str, optional): dataset format.(default='mindrecord')
"""
if num_files == 1:
file_names = [file_name]
else:
suffix = len(str(num_files - 1))
file_names = ["{}{}".format(file_name, str(x).rjust(suffix, '0'))
for x in range(num_files)]
return SaveOp(self).save(file_names, file_type)
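# A minimal usage sketch from the caller's side (based on test_case_03 in the test module added
# by this change; the output path is illustrative):
#   d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
#   d1.save("/path/to/out.mindrecord")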
def create_tuple_iterator(self, columns=None):
"""
Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data.
......
......@@ -173,6 +173,7 @@ class Iterator:
# Convert python node into C node and add to C layer execution tree in postorder traversal.
def __convert_node_postorder(self, node):
self.check_node_type(node)
op_type = self.__get_dataset_type(node)
c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
......@@ -224,6 +225,10 @@ class Iterator:
self._index += 1
return data
@abstractmethod
def check_node_type(self, node):
pass
def get_output_shapes(self):
return [t for t in self.depipeline.GetOutputShapes()]
......@@ -245,11 +250,27 @@ class Iterator:
def __deepcopy__(self, memo):
return self
class SaveOp(Iterator):
"""
The derived class of Iterator that saves the transformed data into mindrecord files.
"""
def get_next(self):
pass
def check_node_type(self, node):
if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)):
logger.warning("Used shuffle, repeat, batch before save operator.")
def save(self, file_names, file_type):
return self.depipeline.SaveDataset(file_names, file_type)
class DictIterator(Iterator):
"""
The derived class of Iterator with dict type.
"""
def check_node_type(self, node):
pass
def __iter__(self):
return self
......@@ -269,6 +290,8 @@ class TupleIterator(Iterator):
"""
The derived class of Iterator with list type.
"""
def check_node_type(self, node):
pass
def __init__(self, dataset, columns=None):
if columns is not None:
......
......@@ -246,7 +246,24 @@ def check_celebadataset(method):
return new_method
def check_save(method):
"""A wrapper that wrap a parameter checker to the save op."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_files']
nreq_param_str = ['file_name', 'file_type']
validate_dataset_param_value(nreq_param_int, param_dict, int)
if param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000:
raise ValueError("num_files should be between {} and {}.".format(1, 1000))
validate_dataset_param_value(nreq_param_str, param_dict, str)
if param_dict.get('file_type') != 'mindrecord':
raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type')))
return method(self, *args, **kwargs)
return new_method
def check_minddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This is the test module for saveOp.
"""
import os
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
import numpy as np
import pytest
CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
FILES_NUM = 1
num_readers = 1
@pytest.fixture(name="add_and_remove_cv_file")
def fixture_remove():
"""add/remove cv file"""
if os.path.exists("{}".format(CV_FILE_NAME1)):
os.remove("{}".format(CV_FILE_NAME1))
if os.path.exists("{}.db".format(CV_FILE_NAME1)):
os.remove("{}.db".format(CV_FILE_NAME1))
if os.path.exists("{}".format(CV_FILE_NAME2)):
os.remove("{}".format(CV_FILE_NAME2))
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
os.remove("{}.db".format(CV_FILE_NAME2))
yield "yield_cv_data"
if os.path.exists("{}".format(CV_FILE_NAME1)):
os.remove("{}".format(CV_FILE_NAME1))
if os.path.exists("{}.db".format(CV_FILE_NAME1)):
os.remove("{}.db".format(CV_FILE_NAME1))
if os.path.exists("{}".format(CV_FILE_NAME2)):
os.remove("{}".format(CV_FILE_NAME2))
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
os.remove("{}.db".format(CV_FILE_NAME2))
def test_case_00(add_and_remove_cv_file): # only bin data
data = [{"image1": bytes("image1 bytes abc", encoding='UTF-8'),
"image2": bytes("image1 bytes def", encoding='UTF-8'),
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
"image5": bytes("image1 bytes mno", encoding='UTF-8')},
{"image1": bytes("image2 bytes abc", encoding='UTF-8'),
"image2": bytes("image2 bytes def", encoding='UTF-8'),
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
"image5": bytes("image2 bytes mno", encoding='UTF-8')},
{"image1": bytes("image3 bytes abc", encoding='UTF-8'),
"image2": bytes("image3 bytes def", encoding='UTF-8'),
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
"image5": bytes("image3 bytes mno", encoding='UTF-8')},
{"image1": bytes("image5 bytes abc", encoding='UTF-8'),
"image2": bytes("image5 bytes def", encoding='UTF-8'),
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
"image5": bytes("image5 bytes mno", encoding='UTF-8')},
{"image1": bytes("image6 bytes abc", encoding='UTF-8'),
"image2": bytes("image6 bytes def", encoding='UTF-8'),
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
"image5": bytes("image6 bytes mno", encoding='UTF-8')}]
schema = {
"image1": {"type": "bytes"},
"image2": {"type": "bytes"},
"image3": {"type": "bytes"},
"image4": {"type": "bytes"},
"image5": {"type": "bytes"}}
writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
writer.add_schema(schema, "schema")
writer.write_raw_data(data)
writer.commit()
d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
d1.save(CV_FILE_NAME2, FILES_NUM)
data_value_to_list = []
for item in data:
new_data = {}
new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
data_value_to_list.append(new_data)
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
num_parallel_workers=num_readers,
shuffle=False)
assert d2.get_dataset_size() == 5
num_iter = 0
for item in d2.create_dict_iterator():
assert len(item) == 5
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] ==
data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 5
def test_case_01(add_and_remove_cv_file): # only raw data
data = [{"file_name": "001.jpg", "label": 43},
{"file_name": "002.jpg", "label": 91},
{"file_name": "003.jpg", "label": 61},
{"file_name": "004.jpg", "label": 29},
{"file_name": "005.jpg", "label": 78},
{"file_name": "006.jpg", "label": 37}]
schema = {"file_name": {"type": "string"},
"label": {"type": "int32"}
}
writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
writer.add_schema(schema, "schema")
writer.write_raw_data(data)
writer.commit()
d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
d1.save(CV_FILE_NAME2, FILES_NUM)
data_value_to_list = []
for item in data:
new_data = {}
new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
data_value_to_list.append(new_data)
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
num_parallel_workers=num_readers,
shuffle=False)
assert d2.get_dataset_size() == 6
num_iter = 0
for item in d2.create_dict_iterator():
logger.info(item)
assert len(item) == 2
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] ==
data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
def test_case_02(add_and_remove_cv_file):  # multi-bytes
data = [{"file_name": "001.jpg", "label": 43,
"float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12345,
"float64": 1987654321.123456785,
"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int32),
"source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
"image2": bytes("image1 bytes def", encoding='UTF-8'),
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
"image5": bytes("image1 bytes mno", encoding='UTF-8')},
{"file_name": "002.jpg", "label": 91,
"float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12445,
"float64": 1987654321.123456786,
"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int32),
"source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"image1": bytes("image2 bytes abc", encoding='UTF-8'),
"image2": bytes("image2 bytes def", encoding='UTF-8'),
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
"image5": bytes("image2 bytes mno", encoding='UTF-8')},
{"file_name": "003.jpg", "label": 61,
"float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12545,
"float64": 1987654321.123456787,
"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int32),
"source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"image1": bytes("image3 bytes abc", encoding='UTF-8'),
"image2": bytes("image3 bytes def", encoding='UTF-8'),
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
"image5": bytes("image3 bytes mno", encoding='UTF-8')},
{"file_name": "004.jpg", "label": 29,
"float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12645,
"float64": 1987654321.123456788,
"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int32),
"source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"image1": bytes("image4 bytes abc", encoding='UTF-8'),
"image2": bytes("image4 bytes def", encoding='UTF-8'),
"image3": bytes("image4 bytes ghi", encoding='UTF-8'),
"image4": bytes("image4 bytes jkl", encoding='UTF-8'),
"image5": bytes("image4 bytes mno", encoding='UTF-8')},
{"file_name": "005.jpg", "label": 78,
"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12745,
"float64": 1987654321.123456789,
"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int32),
"source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"image1": bytes("image5 bytes abc", encoding='UTF-8'),
"image2": bytes("image5 bytes def", encoding='UTF-8'),
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
"image5": bytes("image5 bytes mno", encoding='UTF-8')},
{"file_name": "006.jpg", "label": 37,
"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
"float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
123414314.2141243, 87.1212122], dtype=np.float64),
"float32": 3456.12745,
"float64": 1987654321.123456789,
"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int32),
"source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"image1": bytes("image6 bytes abc", encoding='UTF-8'),
"image2": bytes("image6 bytes def", encoding='UTF-8'),
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
"image5": bytes("image6 bytes mno", encoding='UTF-8')}
]
schema = {"file_name": {"type": "string"},
"float32_array": {"type": "float32", "shape": [-1]},
"float64_array": {"type": "float64", "shape": [-1]},
"float32": {"type": "float32"},
"float64": {"type": "float64"},
"source_sos_ids": {"type": "int32", "shape": [-1]},
"source_sos_mask": {"type": "int64", "shape": [-1]},
"image1": {"type": "bytes"},
"image2": {"type": "bytes"},
"image3": {"type": "bytes"},
"label": {"type": "int32"},
"image4": {"type": "bytes"},
"image5": {"type": "bytes"}}
writer = FileWriter(CV_FILE_NAME1, FILES_NUM)
writer.add_schema(schema, "schema")
writer.write_raw_data(data)
writer.commit()
d1 = ds.MindDataset(CV_FILE_NAME1, None, num_readers, shuffle=False)
d1.save(CV_FILE_NAME2, FILES_NUM)
data_value_to_list = []
for item in data:
new_data = {}
new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
new_data['float32_array'] = item["float32_array"]
new_data['float64_array'] = item["float64_array"]
new_data['float32'] = item["float32"]
new_data['float64'] = item["float64"]
new_data['source_sos_ids'] = item["source_sos_ids"]
new_data['source_sos_mask'] = item["source_sos_mask"]
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
data_value_to_list.append(new_data)
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
num_parallel_workers=num_readers,
shuffle=False)
assert d2.get_dataset_size() == 6
num_iter = 0
for item in d2.create_dict_iterator():
assert len(item) == 13
for field in item:
if isinstance(item[field], np.ndarray):
if item[field].dtype == np.float32:
assert (item[field] ==
np.array(data_value_to_list[num_iter][field], np.float32)).all()
else:
assert (item[field] ==
data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
def generator_1d():
for i in range(10):
yield (np.array([i]),)
def test_case_03(add_and_remove_cv_file):
# apply dataset operations
d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
d1.save(CV_FILE_NAME2)
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
num_parallel_workers=num_readers,
shuffle=False)
i = 0
for item in d2.create_dict_iterator(): # each data is a dictionary
golden = np.array([i])
assert np.array_equal(item["data"], golden)
i = i + 1
def generator_with_type(t):
for i in range(64):
yield (np.array([i], dtype=t),)
def type_tester(t):
logger.info("Test with Type {}".format(t.__name__))
# apply dataset operations
data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], shuffle=False)
data1 = data1.batch(4)
data1 = data1.repeat(3)
data1.save(CV_FILE_NAME2)
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
num_parallel_workers=num_readers,
shuffle=False)
i = 0
num_repeat = 0
for item in d2.create_dict_iterator(): # each data is a dictionary
golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
logger.info(item)
assert np.array_equal(item["data"], golden)
i = i + 4
if i == 64:
i = 0
num_repeat += 1
assert num_repeat == 3
if os.path.exists("{}".format(CV_FILE_NAME2)):
os.remove("{}".format(CV_FILE_NAME2))
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
os.remove("{}.db".format(CV_FILE_NAME2))
def test_case_04():
# uint8 would drop its shape because mindrecord stores uint8 as bytes
types = [np.int8, np.int16, np.int32, np.int64,
np.uint16, np.uint32, np.float32, np.float64]
for t in types:
type_tester(t)
def test_case_05(add_and_remove_cv_file):
d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
with pytest.raises(Exception, match="num_files should between 1 and 1000."):
d1.save(CV_FILE_NAME2, 0)
def test_case_06(add_and_remove_cv_file):
d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
d1.save(CV_FILE_NAME2, 1, "tfrecord")