提交 decf12cd 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1317 [MD]add compress for nlp data in mindrecord

Merge pull request !1317 from liyong126/mindrecord_compress
...@@ -112,25 +112,26 @@ Status MindRecordOp::Init() { ...@@ -112,25 +112,26 @@ Status MindRecordOp::Init() {
data_schema_ = std::make_unique<DataSchema>(); data_schema_ = std::make_unique<DataSchema>();
std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->GetShardHeader()->GetSchemas(); std::vector<std::string> col_names = shard_reader_->get_shard_column()->GetColumnName();
// check whether schema exists, if so use the first one CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found");
CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found"); std::vector<mindrecord::ColumnDataType> col_data_types = shard_reader_->get_shard_column()->GeColumnDataType();
mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"]; std::vector<std::vector<int64_t>> col_shapes = shard_reader_->get_shard_column()->GetColumnShape();
bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything
std::map<std::string, int32_t> colname_to_ind; std::map<std::string, int32_t> colname_to_ind;
for (mindrecord::json::iterator it = mr_schema.begin(); it != mr_schema.end(); ++it) { for (uint32_t i = 0; i < col_names.size(); i++) {
std::string colname = it.key(); // key of the json, column name std::string colname = col_names[i];
mindrecord::json it_value = it.value(); // value, which contains type info and may contain shape
ColDescriptor col_desc; ColDescriptor col_desc;
TensorShape t_shape = TensorShape::CreateUnknownRankShape(); // shape of tensor, default unknown TensorShape t_shape = TensorShape::CreateUnknownRankShape(); // shape of tensor, default unknown
std::string type_str = (it_value["type"] == "bytes" || it_value["type"] == "string") ? "uint8" : it_value["type"]; std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]];
DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"} DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"}
if (it_value["type"] == "bytes") { // rank = 1
if (col_data_types[i] == mindrecord::ColumnBytes || col_data_types[i] == mindrecord::ColumnString) { // rank = 1
col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1); col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1);
} else if (it_value.find("shape") != it_value.end()) { } else if (col_shapes[i].size() > 0) {
std::vector<dsize_t> vec(it_value["shape"].size()); // temporary vector to hold shape std::vector<dsize_t> vec(col_shapes[i].size()); // temporary vector to hold shape
(void)std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin()); (void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin());
t_shape = TensorShape(vec); t_shape = TensorShape(vec);
col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape);
} else { // unknown shape } else { // unknown shape
...@@ -162,30 +163,7 @@ Status MindRecordOp::Init() { ...@@ -162,30 +163,7 @@ Status MindRecordOp::Init() {
num_rows_ = shard_reader_->GetNumRows(); num_rows_ = shard_reader_->GetNumRows();
// Compute how many buffers we would need to accomplish rowsPerBuffer // Compute how many buffers we would need to accomplish rowsPerBuffer
buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_;
RETURN_IF_NOT_OK(SetColumnsBlob());
return Status::OK();
}
Status MindRecordOp::SetColumnsBlob() {
columns_blob_ = shard_reader_->GetBlobFields().second;
// get the exactly blob fields by columns_to_load_
std::vector<std::string> columns_blob_exact;
for (auto &blob_field : columns_blob_) {
for (auto &column : columns_to_load_) {
if (column.compare(blob_field) == 0) {
columns_blob_exact.push_back(blob_field);
break;
}
}
}
columns_blob_index_ = std::vector<int32_t>(columns_to_load_.size(), -1);
int32_t iBlob = 0;
for (auto &blob_exact : columns_blob_exact) {
columns_blob_index_[column_name_id_map_[blob_exact]] = iBlob++;
}
return Status::OK(); return Status::OK();
} }
...@@ -215,248 +193,18 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const { ...@@ -215,248 +193,18 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
} }
} }
template <typename T>
Status MindRecordOp::LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const {
TensorShape new_shape = TensorShape::CreateUnknownRankShape();
const unsigned char *data = nullptr;
std::unique_ptr<T[]> array_data;
std::string string_data;
const ColDescriptor &cur_column = data_schema_->column(i_col);
std::string column_name = columns_to_load_[i_col];
DataType type = cur_column.type();
// load blob column
if (columns_blob_index_[i_col] >= 0 && columns_blob.size() > 0) {
int32_t pos = columns_blob_.size() == 1 ? -1 : columns_blob_index_[i_col];
RETURN_IF_NOT_OK(LoadBlob(&new_shape, &data, columns_blob, pos, cur_column));
} else {
switch (type.value()) {
case DataType::DE_UINT8: {
// For strings (Assume DE_UINT8 is reserved for strings)
RETURN_IF_NOT_OK(LoadByte(&new_shape, &string_data, column_name, columns_json));
data = reinterpret_cast<const unsigned char *>(common::SafeCStr(string_data));
break;
}
case DataType::DE_FLOAT32: {
// For both float scalars and arrays
RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, false));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
case DataType::DE_FLOAT64: {
// For both double scalars and arrays
RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, true));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
default: {
// For both integers scalars and arrays
RETURN_IF_NOT_OK(LoadInt(&new_shape, &array_data, column_name, columns_json, cur_column));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
}
}
// Create Tensor with given details
RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, cur_column.tensorImpl(), new_shape, type, data));
return Status::OK();
}
Status MindRecordOp::LoadBlob(TensorShape *new_shape, const unsigned char **data,
const std::vector<uint8_t> &columns_blob, const int32_t pos,
const ColDescriptor &column) {
const auto kColumnSize = column.type().SizeInBytes();
if (kColumnSize == 0) {
RETURN_STATUS_UNEXPECTED("column size is null");
}
if (pos == -1) {
if (column.hasShape()) {
*new_shape = TensorShape::CreateUnknownRankShape();
RETURN_IF_NOT_OK(
column.MaterializeTensorShape(static_cast<int32_t>(columns_blob.size() / kColumnSize), new_shape));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_blob.size() / kColumnSize)};
*new_shape = TensorShape(shapeDetails);
}
*data = reinterpret_cast<const uint8_t *>(&(columns_blob[0]));
return Status::OK();
}
auto uint64_from_bytes = [&](int64_t pos) {
uint64_t result = 0;
for (uint64_t n = 0; n < kInt64Len; n++) {
result = (result << 8) + columns_blob[pos + n];
}
return result;
};
uint64_t iStart = 0;
for (int32_t i = 0; i < pos; i++) {
uint64_t num_bytes = uint64_from_bytes(iStart);
iStart += kInt64Len + num_bytes;
}
uint64_t num_bytes = uint64_from_bytes(iStart);
iStart += kInt64Len;
if (column.hasShape()) {
*new_shape = TensorShape::CreateUnknownRankShape();
RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_bytes / kColumnSize), new_shape));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(num_bytes / kColumnSize)};
*new_shape = TensorShape(shapeDetails);
}
*data = reinterpret_cast<const uint8_t *>(&(columns_blob[iStart]));
return Status::OK();
}
template <typename T>
Status MindRecordOp::LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double) {
if (!columns_json[column_name].is_array()) {
T value = 0;
RETURN_IF_NOT_OK(GetFloat(&value, columns_json[column_name], use_double));
*new_shape = TensorShape::CreateScalar();
*array_data = std::make_unique<T[]>(1);
(*array_data)[0] = value;
} else {
if (column.hasShape()) {
*new_shape = TensorShape(column.shape());
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_json[column_name].size())};
*new_shape = TensorShape(shapeDetails);
}
int idx = 0;
*array_data = std::make_unique<T[]>(new_shape->NumOfElements());
for (auto &element : columns_json[column_name]) {
T value = 0;
RETURN_IF_NOT_OK(GetFloat(&value, element, use_double));
(*array_data)[idx++] = value;
}
}
return Status::OK();
}
template <typename T>
Status MindRecordOp::GetFloat(T *value, const mindrecord::json &data, bool use_double) {
if (data.is_number()) {
*value = data;
} else if (data.is_string()) {
try {
if (use_double) {
*value = data.get<double>();
} else {
*value = data.get<float>();
}
} catch (mindrecord::json::exception &e) {
RETURN_STATUS_UNEXPECTED("Conversion to float failed.");
}
} else {
RETURN_STATUS_UNEXPECTED("Conversion to float failed.");
}
return Status::OK();
}
template <typename T>
Status MindRecordOp::LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column) {
if (!columns_json[column_name].is_array()) {
T value = 0;
RETURN_IF_NOT_OK(GetInt(&value, columns_json[column_name]));
*new_shape = TensorShape::CreateScalar();
*array_data = std::make_unique<T[]>(1);
(*array_data)[0] = value;
} else {
if (column.hasShape()) {
*new_shape = TensorShape(column.shape());
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_json[column_name].size())};
*new_shape = TensorShape(shapeDetails);
}
int idx = 0;
*array_data = std::make_unique<T[]>(new_shape->NumOfElements());
for (auto &element : columns_json[column_name]) {
T value = 0;
RETURN_IF_NOT_OK(GetInt(&value, element));
(*array_data)[idx++] = value;
}
}
return Status::OK();
}
template <typename T>
Status MindRecordOp::GetInt(T *value, const mindrecord::json &data) {
int64_t temp_value = 0;
bool less_than_zero = false;
if (data.is_number_integer()) {
const mindrecord::json json_zero = 0;
if (data < json_zero) less_than_zero = true;
temp_value = data;
} else if (data.is_string()) {
std::string string_value = data;
if (!string_value.empty() && string_value[0] == '-') {
try {
temp_value = std::stoll(string_value);
less_than_zero = true;
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument.");
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
}
} else {
try {
temp_value = static_cast<int64_t>(std::stoull(string_value));
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument.");
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
}
}
} else {
RETURN_STATUS_UNEXPECTED("Conversion to int failed.");
}
if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed. Out of range");
}
*value = static_cast<T>(temp_value);
return Status::OK();
}
Status MindRecordOp::LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name,
const mindrecord::json &columns_json) {
*string_data = columns_json[column_name];
std::vector<dsize_t> shape_details = {static_cast<dsize_t>(string_data->size())};
*new_shape = TensorShape(shape_details);
return Status::OK();
}
Status MindRecordOp::WorkerEntry(int32_t worker_id) { Status MindRecordOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post(); TaskManager::FindMe()->Post();
std::unique_ptr<IOBlock> io_block; std::unique_ptr<IOBlock> io_block;
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
while (io_block != nullptr) { while (io_block != nullptr) {
if (io_block->eoe() == true) { if (io_block->eoe()) {
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)))); out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
continue; continue;
} }
if (io_block->eof() == true) { if (io_block->eof()) {
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)))); out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
...@@ -521,19 +269,10 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu ...@@ -521,19 +269,10 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
if (tupled_buffer.empty()) break; if (tupled_buffer.empty()) break;
} }
for (const auto &tupled_row : tupled_buffer) { for (const auto &tupled_row : tupled_buffer) {
std::vector<uint8_t> columnsBlob = std::get<0>(tupled_row); std::vector<uint8_t> columns_blob = std::get<0>(tupled_row);
mindrecord::json columns_json = std::get<1>(tupled_row); mindrecord::json columns_json = std::get<1>(tupled_row);
TensorRow tensor_row; TensorRow tensor_row;
for (uint32_t j = 0; j < columns_to_load_.size(); ++j) { RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json));
std::shared_ptr<Tensor> tensor;
const ColDescriptor &cur_column = data_schema_->column(j);
DataType type = cur_column.type();
RETURN_IF_NOT_OK(SwitchLoadFeature(type, &tensor, j, columnsBlob, columns_json));
tensor_row.push_back(std::move(tensor));
}
tensor_table->push_back(std::move(tensor_row)); tensor_table->push_back(std::move(tensor_row));
} }
} }
...@@ -543,48 +282,46 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu ...@@ -543,48 +282,46 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
return Status::OK(); return Status::OK();
} }
Status MindRecordOp::SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col, Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob,
const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) {
const mindrecord::json &columns_json) const { for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) {
switch (type.value()) { auto column_name = columns_to_load_[i_col];
case DataType::DE_BOOL: {
return LoadFeature<bool>(tensor, i_col, columns_blob, columns_json); // Initialize column parameters
} const unsigned char *data = nullptr;
case DataType::DE_INT8: { std::unique_ptr<unsigned char[]> data_ptr;
return LoadFeature<int8_t>(tensor, i_col, columns_blob, columns_json); uint64_t n_bytes = 0;
} mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType;
case DataType::DE_UINT8: { uint64_t column_data_type_size = 1;
return LoadFeature<uint8_t>(tensor, i_col, columns_blob, columns_json); std::vector<int64_t> column_shape;
}
case DataType::DE_INT16: { // Get column data
return LoadFeature<int16_t>(tensor, i_col, columns_blob, columns_json);
} auto has_column = shard_reader_->get_shard_column()->GetColumnValueByName(
case DataType::DE_UINT16: { column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, &column_data_type, &column_data_type_size,
return LoadFeature<uint16_t>(tensor, i_col, columns_blob, columns_json); &column_shape);
} if (has_column == MSRStatus::FAILED) {
case DataType::DE_INT32: { RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader.");
return LoadFeature<int32_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT32: {
return LoadFeature<uint32_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT64: {
return LoadFeature<int64_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT64: {
return LoadFeature<uint64_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_FLOAT32: {
return LoadFeature<float>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_FLOAT64: {
return LoadFeature<double>(tensor, i_col, columns_blob, columns_json);
} }
default: {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, std::shared_ptr<Tensor> tensor;
"mindrecord column list type does not match any known types"); const ColDescriptor &column = data_schema_->column(i_col);
DataType type = column.type();
// Set shape
auto num_elements = n_bytes / column_data_type_size;
if (column.hasShape()) {
auto new_shape = TensorShape(column.shape());
RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_elements), &new_shape));
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(num_elements)};
auto new_shape = TensorShape(shapeDetails);
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data));
} }
tensor_row->push_back(std::move(tensor));
} }
return Status::OK();
} }
Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) {
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <queue> #include <queue>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -31,6 +32,7 @@ ...@@ -31,6 +32,7 @@
#include "dataset/engine/datasetops/source/io_block.h" #include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/queue.h" #include "dataset/util/queue.h"
#include "dataset/util/status.h" #include "dataset/util/status.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h" #include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_reader.h" #include "mindrecord/include/shard_reader.h"
#include "mindrecord/include/common/shard_utils.h" #include "mindrecord/include/common/shard_utils.h"
...@@ -193,8 +195,6 @@ class MindRecordOp : public ParallelOp { ...@@ -193,8 +195,6 @@ class MindRecordOp : public ParallelOp {
Status Init(); Status Init();
Status SetColumnsBlob();
// Base-class override for NodePass visitor acceptor. // Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted. // @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline. // @param modified - Whether this node visit modified the pipeline.
...@@ -205,56 +205,11 @@ class MindRecordOp : public ParallelOp { ...@@ -205,56 +205,11 @@ class MindRecordOp : public ParallelOp {
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id); Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
// Parses a single cell and puts the data into a tensor // Parses a single cell and puts the data into a tensor
// @param tensor - the tensor to put the parsed data in // @param tensor_row - the tensor row to put the parsed data in
// @param i_col - the id of column to parse
// @param columns_blob - the blob data received from the reader // @param columns_blob - the blob data received from the reader
// @param columns_json - the data for fields received from the reader // @param columns_json - the data for fields received from the reader
template <typename T> Status LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob,
Status LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col, const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json);
const mindrecord::json &columns_json) const;
Status SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const;
static Status LoadBlob(TensorShape *new_shape, const unsigned char **data, const std::vector<uint8_t> &columns_blob,
const int32_t pos, const ColDescriptor &column);
// Get shape and data (scalar or array) for tensor to be created (for floats and doubles)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
// @param use_double - boolean to choose between float32 and float64
template <typename T>
static Status LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double);
// Get shape and data (scalar or array) for tensor to be created (for integers)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
template <typename T>
static Status LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column);
static Status LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name,
const mindrecord::json &columns_json);
// Get a single float value from the given json
// @param value - the float to put the value in
// @param arrayData - the given json containing the float
// @param use_double - boolean to choose between float32 and float64
template <typename T>
static Status GetFloat(T *value, const mindrecord::json &data, bool use_double);
// Get a single integer value from the given json
// @param value - the integer to put the value in
// @param arrayData - the given json containing the integer
template <typename T>
static Status GetInt(T *value, const mindrecord::json &data);
Status FetchBlockBuffer(const int32_t &buffer_id); Status FetchBlockBuffer(const int32_t &buffer_id);
......
...@@ -91,8 +91,8 @@ void BindShardReader(const py::module *m) { ...@@ -91,8 +91,8 @@ void BindShardReader(const py::module *m) {
.def("launch", &ShardReader::Launch) .def("launch", &ShardReader::Launch)
.def("get_header", &ShardReader::GetShardHeader) .def("get_header", &ShardReader::GetShardHeader)
.def("get_blob_fields", &ShardReader::GetBlobFields) .def("get_blob_fields", &ShardReader::GetBlobFields)
.def("get_next", .def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) &
(std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy) ShardReader::GetNextPy)
.def("finish", &ShardReader::Finish) .def("finish", &ShardReader::Finish)
.def("close", &ShardReader::Close); .def("close", &ShardReader::Close);
} }
......
...@@ -65,6 +65,9 @@ const int kUnsignedInt4 = 4; ...@@ -65,6 +65,9 @@ const int kUnsignedInt4 = 4;
enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel }; enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel };
const char kVersion[] = "3.0";
const std::vector<std::string> kSupportedVersion = {"2.0", kVersion};
enum ShardType { enum ShardType {
kNLP = 0, kNLP = 0,
kCV = 1, kCV = 1,
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_header.h"
namespace mindspore {
namespace mindrecord {
const uint64_t kUnsignedOne = 1;
const uint64_t kBitsOfByte = 8;
const uint64_t kDataTypeBits = 2;
const uint64_t kNumDataOfByte = 4;
const uint64_t kBytesOfColumnLen = 4;
const uint64_t kDataTypeBitMask = 3;
const uint64_t kDataTypes = 6;
enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type };
enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound };
enum ColumnDataType {
ColumnBytes = 0,
ColumnString = 1,
ColumnInt32 = 2,
ColumnInt64 = 3,
ColumnFloat32 = 4,
ColumnFloat64 = 5,
ColumnNoDataType = 6
};
// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"};
const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8};
const std::vector<std::string> ColumnDataTypeNameNormalized = {"uint8", "uint8", "int32",
"int64", "float32", "float64"};
const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {
{"bytes", ColumnBytes}, {"string", ColumnString}, {"int32", ColumnInt32},
{"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}};
class ShardColumn {
public:
explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true);
~ShardColumn() = default;
/// \brief get column value by column name
MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);
/// \brief compress blob
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob);
/// \brief check if blob compressed
bool CheckCompressBlob() const { return has_compress_blob_; }
uint64_t GetNumBlobColumn() const { return num_blob_column_; }
std::vector<std::string> GetColumnName() { return column_name_; }
std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; }
std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }
/// \brief get column value from blob
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *n_bytes);
private:
/// \brief get column value from json
MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
/// \brief get float value from json
template <typename T>
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);
/// \brief get integer value from json
template <typename T>
MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value);
/// \brief get column offset address and size from blob
MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx);
/// \brief check if column name is available
ColumnCategory CheckColumnName(const std::string &column_name);
/// \brief compress integer column
static vector<uint8_t> CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type);
/// \brief uncompress integer array column
template <typename T>
static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx);
/// \brief convert big-endian bytes to unsigned int
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param i_type integer type
/// \return unsigned int
static uint64_t BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &i_type);
/// \brief convert unsigned int to big-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static std::vector<uint8_t> UIntToBytesBig(uint64_t value, const IntegerType &i_type);
/// \brief convert unsigned int to little-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static std::vector<uint8_t> UIntToBytesLittle(uint64_t value, const IntegerType &i_type);
/// \brief convert unsigned int to little-endian bytes
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param src_i_type source integer typ0e
/// \param dst_i_type (output), destination integer type
/// \return integer
static int64_t BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr);
private:
std::vector<std::string> column_name_; // column name list
std::vector<ColumnDataType> column_data_type_; // column data type list
std::vector<std::vector<int64_t>> column_shape_; // column shape list
std::unordered_map<string, uint64_t> column_name_id_; // column name id map
std::vector<std::string> blob_column_; // blob column list
std::unordered_map<std::string, uint64_t> blob_column_id_; // blob column name id map
bool has_compress_blob_; // if has compress blob
uint64_t num_blob_column_; // number of blob columns
};
} // namespace mindrecord
} // namespace mindspore
#endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_
...@@ -118,8 +118,6 @@ class ShardHeader { ...@@ -118,8 +118,6 @@ class ShardHeader {
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
const string GetVersion() { return version_; }
std::vector<std::string> SerializeHeader(); std::vector<std::string> SerializeHeader();
MSRStatus PagesToFile(const std::string dump_file_name); MSRStatus PagesToFile(const std::string dump_file_name);
...@@ -175,7 +173,6 @@ class ShardHeader { ...@@ -175,7 +173,6 @@ class ShardHeader {
uint32_t shard_count_; uint32_t shard_count_;
uint64_t header_size_; uint64_t header_size_;
uint64_t page_size_; uint64_t page_size_;
string version_ = "2.0";
std::shared_ptr<Index> index_; std::shared_ptr<Index> index_;
std::vector<std::string> shard_addresses_; std::vector<std::string> shard_addresses_;
......
...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
#include <vector> #include <vector>
#include "mindrecord/include/common/shard_utils.h" #include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_category.h" #include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h" #include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_index_generator.h" #include "mindrecord/include/shard_index_generator.h"
#include "mindrecord/include/shard_operator.h" #include "mindrecord/include/shard_operator.h"
...@@ -111,6 +112,10 @@ class ShardReader { ...@@ -111,6 +112,10 @@ class ShardReader {
/// \return the metadata /// \return the metadata
std::shared_ptr<ShardHeader> GetShardHeader() const; std::shared_ptr<ShardHeader> GetShardHeader() const;
/// \brief aim to get columns context
/// \return the columns
std::shared_ptr<ShardColumn> get_shard_column() const;
/// \brief get the number of shards /// \brief get the number of shards
/// \return # of shards /// \return # of shards
int GetShardCount() const; int GetShardCount() const;
...@@ -185,7 +190,7 @@ class ShardReader { ...@@ -185,7 +190,7 @@ class ShardReader {
/// \brief return a batch, given that one is ready, python API /// \brief return a batch, given that one is ready, python API
/// \return a batch of images and image data /// \return a batch of images and image data
std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> GetNextPy(); std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> GetNextPy();
/// \brief get blob filed list /// \brief get blob filed list
/// \return blob field list /// \return blob field list
...@@ -295,16 +300,18 @@ class ShardReader { ...@@ -295,16 +300,18 @@ class ShardReader {
/// \brief get number of classes /// \brief get number of classes
int64_t GetNumClasses(const std::string &category_field); int64_t GetNumClasses(const std::string &category_field);
/// \brief get meta of header
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data); std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
/// \brief get exactly blob fields data by indices
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes, /// \brief extract uncompressed data based on column list
std::vector<uint32_t> &ordered_selected_columns_index); std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);
protected: protected:
uint64_t header_size_; // header size uint64_t header_size_; // header size
uint64_t page_size_; // page size uint64_t page_size_; // page size
int shard_count_; // number of shards int shard_count_; // number of shards
std::shared_ptr<ShardHeader> shard_header_; // shard header std::shared_ptr<ShardHeader> shard_header_; // shard header
std::shared_ptr<ShardColumn> shard_column_; // shard column
std::vector<sqlite3 *> database_paths_; // sqlite handle list std::vector<sqlite3 *> database_paths_; // sqlite handle list
std::vector<string> file_paths_; // file paths std::vector<string> file_paths_; // file paths
......
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "mindrecord/include/common/shard_utils.h" #include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h" #include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_header.h" #include "mindrecord/include/shard_header.h"
#include "mindrecord/include/shard_index.h" #include "mindrecord/include/shard_index.h"
...@@ -242,7 +243,8 @@ class ShardWriter { ...@@ -242,7 +243,8 @@ class ShardWriter {
std::vector<std::string> file_paths_; // file paths std::vector<std::string> file_paths_; // file paths
std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles
std::shared_ptr<ShardHeader> shard_header_; // shard headers std::shared_ptr<ShardHeader> shard_header_; // shard header
std::shared_ptr<ShardColumn> shard_column_; // shard columns
std::map<uint64_t, std::map<int, std::string>> err_mg_; // used for storing error raw_data info std::map<uint64_t, std::map<int, std::string>> err_mg_; // used for storing error raw_data info
......
...@@ -133,6 +133,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa ...@@ -133,6 +133,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
shard_header_ = std::make_shared<ShardHeader>(sh); shard_header_ = std::make_shared<ShardHeader>(sh);
header_size_ = shard_header_->GetHeaderSize(); header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize(); page_size_ = shard_header_->GetPageSize();
// version < 3.0
if (first_meta_data["version"] < kVersion) {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
} else {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
}
num_rows_ = 0; num_rows_ = 0;
auto row_group_summary = ReadRowGroupSummary(); auto row_group_summary = ReadRowGroupSummary();
for (const auto &rg : row_group_summary) { for (const auto &rg : row_group_summary) {
...@@ -226,6 +232,8 @@ void ShardReader::Close() { ...@@ -226,6 +232,8 @@ void ShardReader::Close() {
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; } std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }
std::shared_ptr<ShardColumn> ShardReader::get_shard_column() const { return shard_column_; }
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }
int ShardReader::GetNumRows() const { return num_rows_; } int ShardReader::GetNumRows() const { return num_rows_; }
...@@ -1059,36 +1067,6 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u ...@@ -1059,36 +1067,6 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
return SUCCESS; return SUCCESS;
} }
std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns(
std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) {
std::vector<uint8_t> exactly_blob_fields_bytes;
auto uint64_from_bytes = [&](int64_t pos) {
uint64_t result = 0;
for (uint64_t n = 0; n < kInt64Len; n++) {
result = (result << 8) + blob_fields_bytes[pos + n];
}
return result;
};
// get the exactly blob fields
uint32_t current_index = 0;
uint64_t current_offset = 0;
uint64_t data_len = uint64_from_bytes(current_offset);
while (current_offset < blob_fields_bytes.size()) {
if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(),
[&current_index](uint32_t &index) { return index == current_index; })) {
exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset,
blob_fields_bytes.begin() + current_offset + kInt64Len + data_len);
}
current_index++;
current_offset += kInt64Len + data_len;
data_len = uint64_from_bytes(current_offset);
}
return exactly_blob_fields_bytes;
}
TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) { TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
// All tasks are done // All tasks are done
if (task_id >= static_cast<int>(tasks_.Size())) { if (task_id >= static_cast<int>(tasks_.Size())) {
...@@ -1126,40 +1104,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ ...@@ -1126,40 +1104,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>()); return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
} }
// extract the exactly blob bytes by selected columns
std::vector<uint8_t> images_with_exact_columns;
if (selected_columns_.size() == 0) {
images_with_exact_columns = images;
} else {
auto blob_fields = GetBlobFields();
std::vector<uint32_t> ordered_selected_columns_index;
uint32_t index = 0;
for (auto &blob_field : blob_fields.second) {
for (auto &field : selected_columns_) {
if (field.compare(blob_field) == 0) {
ordered_selected_columns_index.push_back(index);
break;
}
}
index++;
}
if (ordered_selected_columns_index.size() != 0) {
// extract the images
if (blob_fields.second.size() == 1) {
if (ordered_selected_columns_index.size() == 1) {
images_with_exact_columns = images;
}
} else {
images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index);
}
}
}
// Deliver batch data to output map // Deliver batch data to output map
std::vector<std::tuple<std::vector<uint8_t>, json>> batch; std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task))); batch.emplace_back(std::move(images), std::move(std::get<2>(task)));
return std::make_pair(SUCCESS, std::move(batch)); return std::make_pair(SUCCESS, std::move(batch));
} }
...@@ -1369,16 +1317,41 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNextById(con ...@@ -1369,16 +1317,41 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNextById(con
return std::move(ret.second); return std::move(ret.second);
} }
std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> ShardReader::GetNextPy() { std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardReader::UnCompressBlob(
const std::vector<uint8_t> &raw_blob_data) {
auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_;
auto blob_fields = GetBlobFields().second;
std::vector<std::vector<uint8_t>> blob_data;
for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) {
if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue;
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0;
auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << ".";
return {FAILED, std::vector<std::vector<uint8_t>>(blob_fields.size(), std::vector<uint8_t>())};
}
if (data == nullptr) {
data = reinterpret_cast<const unsigned char *>(data_ptr.get());
}
std::vector<uint8_t> column(data, data + (n_bytes / sizeof(unsigned char)));
blob_data.push_back(column);
}
return {SUCCESS, blob_data};
}
std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> ShardReader::GetNextPy() {
auto res = GetNext(); auto res = GetNext();
vector<std::tuple<std::vector<uint8_t>, pybind11::object>> jsonData; vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> data;
std::transform(res.begin(), res.end(), std::back_inserter(jsonData), std::transform(res.begin(), res.end(), std::back_inserter(data),
[](const std::tuple<std::vector<uint8_t>, json> &item) { [this](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item); auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j); pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj)); auto ret = UnCompressBlob(std::get<0>(item));
return std::make_tuple(ret.second, std::move(obj));
}); });
return jsonData; return data;
} }
void ShardReader::Reset() { void ShardReader::Reset() {
......
...@@ -206,6 +206,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { ...@@ -206,6 +206,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
MS_LOG(ERROR) << "Open file failed"; MS_LOG(ERROR) << "Open file failed";
return FAILED; return FAILED;
} }
shard_column_ = std::make_shared<ShardColumn>(shard_header_);
return SUCCESS; return SUCCESS;
} }
...@@ -271,6 +272,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) ...@@ -271,6 +272,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
shard_header_ = header_data; shard_header_ = header_data;
shard_header_->SetHeaderSize(header_size_); shard_header_->SetHeaderSize(header_size_);
shard_header_->SetPageSize(page_size_); shard_header_->SetPageSize(page_size_);
shard_column_ = std::make_shared<ShardColumn>(shard_header_);
return SUCCESS; return SUCCESS;
} }
...@@ -608,6 +610,14 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json> ...@@ -608,6 +610,14 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
MS_LOG(ERROR) << "IO error / there is no free disk to be used"; MS_LOG(ERROR) << "IO error / there is no free disk to be used";
return FAILED; return FAILED;
} }
// compress blob
if (shard_column_->CheckCompressBlob()) {
for (auto &blob : blob_data) {
blob = shard_column_->CompressBlob(blob);
}
}
// Add 4-bytes dummy blob data if no any blob fields // Add 4-bytes dummy blob data if no any blob fields
if (blob_data.size() == 0 && raw_data.size() > 0) { if (blob_data.size() == 0 && raw_data.size() > 0) {
blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0)); blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindrecord/include/shard_column.h"
#include "common/utils.h"
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_error.h"
namespace mindspore {
namespace mindrecord {
ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
auto first_schema = shard_header->GetSchemas()[0];
auto schema = first_schema->GetSchema()["schema"];
bool has_integer_array = false;
for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
const std::string &column_name = it.key();
column_name_.push_back(column_name);
json it_value = it.value();
std::string str_type = it_value["type"];
column_data_type_.push_back(ColumnDataTypeMap.at(str_type));
if (it_value.find("shape") != it_value.end()) {
std::vector<int64_t> vec(it_value["shape"].size());
std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin());
column_shape_.push_back(vec);
if (str_type == "int32" || str_type == "int64") {
has_integer_array = true;
}
} else {
std::vector<int64_t> vec = {};
column_shape_.push_back(vec);
}
}
for (uint64_t i = 0; i < column_name_.size(); i++) {
column_name_id_[column_name_[i]] = i;
}
auto blob_fields = first_schema->GetBlobFields();
for (const auto &field : blob_fields) {
blob_column_.push_back(field);
}
for (uint64_t i = 0; i < blob_column_.size(); i++) {
blob_column_id_[blob_column_[i]] = i;
}
has_compress_blob_ = (compress_integer && has_integer_array);
num_blob_column_ = blob_column_.size();
}
MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape) {
// Skip if column not found
auto column_category = CheckColumnName(column_name);
if (column_category == ColumnNotFound) {
return FAILED;
}
// Get data type and size
auto column_id = column_name_id_[column_name];
*column_data_type = column_data_type_[column_id];
*column_data_type_size = ColumnDataTypeSize[*column_data_type];
*column_shape = column_shape_[column_id];
// Retrieve value from json
if (column_category == ColumnInRaw) {
if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) {
MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << ".";
return FAILED;
}
*data = reinterpret_cast<const unsigned char *>(data_ptr->get());
return SUCCESS;
}
// Retrieve value from blob
if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) {
MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << ".";
return FAILED;
}
if (*data == nullptr) {
*data = reinterpret_cast<const unsigned char *>(data_ptr->get());
}
return SUCCESS;
}
MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) {
auto column_id = column_name_id_[column_name];
auto column_data_type = column_data_type_[column_id];
// Initialize num bytes
*n_bytes = ColumnDataTypeSize[column_data_type];
auto json_column_value = columns_json[column_name];
switch (column_data_type) {
case ColumnFloat32: {
return GetFloat<float>(data_ptr, json_column_value, false);
}
case ColumnFloat64: {
return GetFloat<double>(data_ptr, json_column_value, true);
}
case ColumnInt32: {
return GetInt<int32_t>(data_ptr, json_column_value);
}
case ColumnInt64: {
return GetInt<int64_t>(data_ptr, json_column_value);
}
default: {
// Convert string to c_str
std::string tmp_string = json_column_value;
*n_bytes = tmp_string.size();
auto data = reinterpret_cast<const unsigned char *>(common::SafeCStr(tmp_string));
*data_ptr = std::make_unique<unsigned char[]>(*n_bytes);
for (uint32_t i = 0; i < *n_bytes; i++) {
(*data_ptr)[i] = *(data + i);
}
break;
}
}
return SUCCESS;
}
template <typename T>
MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
bool use_double) {
std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
if (!json_column_value.is_string() && !json_column_value.is_number()) {
MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
return FAILED;
}
if (json_column_value.is_number()) {
array_data[0] = json_column_value;
} else {
// Convert string to float
try {
if (use_double) {
array_data[0] = json_column_value.get<double>();
} else {
array_data[0] = json_column_value.get<float>();
}
} catch (json::exception &e) {
MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
return FAILED;
}
}
auto data = reinterpret_cast<const unsigned char *>(array_data.get());
*data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
for (uint32_t i = 0; i < sizeof(T); i++) {
(*data_ptr)[i] = *(data + i);
}
return SUCCESS;
}
template <typename T>
MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) {
std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
int64_t temp_value;
bool less_than_zero = false;
if (json_column_value.is_number_integer()) {
const json json_zero = 0;
if (json_column_value < json_zero) less_than_zero = true;
temp_value = json_column_value;
} else if (json_column_value.is_string()) {
std::string string_value = json_column_value;
if (!string_value.empty() && string_value[0] == '-') {
try {
temp_value = std::stoll(string_value);
less_than_zero = true;
} catch (std::invalid_argument &e) {
MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
return FAILED;
} catch (std::out_of_range &e) {
MS_LOG(ERROR) << "Conversion to int failed, out of range.";
return FAILED;
}
} else {
try {
temp_value = static_cast<int64_t>(std::stoull(string_value));
} catch (std::invalid_argument &e) {
MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
return FAILED;
} catch (std::out_of_range &e) {
MS_LOG(ERROR) << "Conversion to int failed, out of range.";
return FAILED;
}
}
} else {
MS_LOG(ERROR) << "Conversion to int failed.";
return FAILED;
}
if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
MS_LOG(ERROR) << "Conversion to int failed. Out of range";
return FAILED;
}
array_data[0] = static_cast<T>(temp_value);
auto data = reinterpret_cast<const unsigned char *>(array_data.get());
*data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
for (uint32_t i = 0; i < sizeof(T); i++) {
(*data_ptr)[i] = *(data + i);
}
return SUCCESS;
}
MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *n_bytes) {
uint64_t offset_address = 0;
auto column_id = column_name_id_[column_name];
if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) {
return FAILED;
}
auto column_data_type = column_data_type_[column_id];
if (has_compress_blob_ && column_data_type == ColumnInt32) {
if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
return FAILED;
}
} else if (has_compress_blob_ && column_data_type == ColumnInt64) {
if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
return FAILED;
}
} else {
*data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address]));
}
return SUCCESS;
}
ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
auto it_column = column_name_id_.find(column_name);
if (it_column == column_name_id_.end()) {
return ColumnNotFound;
}
auto it_blob = blob_column_id_.find(column_name);
return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob;
}
std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) {
// Skip if no compress columns
if (!CheckCompressBlob()) return blob;
std::vector<uint8_t> dst_blob;
uint64_t i_src = 0;
for (int64_t i = 0; i < num_blob_column_; i++) {
// Get column data type
auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]];
auto int_type = src_data_type == ColumnInt32 ? kInt32Type : kInt64Type;
// Compress and return is blob has 1 column only
if (num_blob_column_ == 1) {
return CompressInt(blob, int_type);
}
// Just copy and continue if column dat type is not int32/int64
uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type);
if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) {
dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes);
i_src += kInt64Len + num_bytes;
continue;
}
// Get column slice in source blob
std::vector<uint8_t> blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes);
// Compress column
auto dst_blob_slice = CompressInt(blob_slice, int_type);
// Get new column size
auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type);
// Append new colmn size
dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end());
// Append new colmn data
dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end());
i_src += kInt64Len + num_bytes;
}
MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << ".";
return dst_blob;
}
vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
uint64_t i_size = kUnsignedOne << int_type;
// Get number of elements
uint64_t src_n_int = src_bytes.size() / i_size;
// Calculate bitmap size (bytes)
uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte;
// Initilize destination blob, more space than needed, will be resized
vector<uint8_t> dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0);
// Write number of elements to destination blob
vector<uint8_t> size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type);
for (uint64_t n = 0; n < kBytesOfColumnLen; n++) {
dst_bytes[n] = size_by_bytes[n];
}
// Write compressed int
uint64_t i_dst = kBytesOfColumnLen + bitmap_size;
for (uint64_t i = 0; i < src_n_int; i++) {
// Initialize destination data type
IntegerType dst_int_type = kInt8Type;
// Shift to next int position
uint64_t pos = i * (kUnsignedOne << int_type);
// Narrow down this int
int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);
// Write this int to destination blob
uint64_t u_n = *reinterpret_cast<uint64_t *>(&i_n);
auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type);
for (uint64_t j = 0; j < (kUnsignedOne << dst_int_type); j++) {
dst_bytes[i_dst++] = temp_bytes[j];
}
// Update date type in bit map
dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |=
(dst_int_type << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte))));
}
// Resize destination blob
dst_bytes.resize(i_dst);
MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << ".";
return dst_bytes;
}
MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx) {
if (num_blob_column_ == 1) {
*num_bytes = columns_blob.size();
*shift_idx = 0;
return SUCCESS;
}
auto blob_id = blob_column_id_[column_name_[column_id]];
for (int32_t i = 0; i < blob_id; i++) {
*shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
}
*num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
(*shift_idx) += kInt64Len;
return SUCCESS;
}
template <typename T>
MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes,
uint64_t shift_idx) {
auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type);
*num_bytes = sizeof(T) * num_elements;
// Parse integer array
uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte;
auto array_data = std::make_unique<T[]>(num_elements);
for (uint64_t i = 0; i < num_elements; i++) {
uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte];
uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask;
auto mr_int_type = static_cast<IntegerType>(i_type);
int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type);
i_source += (kUnsignedOne << i_type);
array_data[i] = static_cast<T>(i64);
}
auto data = reinterpret_cast<const unsigned char *>(array_data.get());
*data_ptr = std::make_unique<unsigned char[]>(*num_bytes);
memcpy(data_ptr->get(), data, *num_bytes);
return SUCCESS;
}
uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &i_type) {
uint64_t result = 0;
for (uint64_t i = 0; i < (kUnsignedOne << i_type); i++) {
result = (result << kBitsOfByte) + bytes_array[pos + i];
}
return result;
}
std::vector<uint8_t> ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) {
uint64_t n_bytes = kUnsignedOne << i_type;
std::vector<uint8_t> result(n_bytes, 0);
for (uint64_t i = 0; i < n_bytes; i++) {
result[n_bytes - 1 - i] = value & std::numeric_limits<uint8_t>::max();
value >>= kBitsOfByte;
}
return result;
}
std::vector<uint8_t> ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) {
uint64_t n_bytes = kUnsignedOne << i_type;
std::vector<uint8_t> result(n_bytes, 0);
for (uint64_t i = 0; i < n_bytes; i++) {
result[i] = value & std::numeric_limits<uint8_t>::max();
value >>= kBitsOfByte;
}
return result;
}
int64_t ShardColumn::BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &src_i_type, IntegerType *dst_i_type) {
uint64_t u_temp = 0;
for (uint64_t i = 0; i < (kUnsignedOne << src_i_type); i++) {
u_temp = (u_temp << kBitsOfByte) + bytes_array[pos + (kUnsignedOne << src_i_type) - kUnsignedOne - i];
}
int64_t i_out;
switch (src_i_type) {
case kInt8Type: {
i_out = (int8_t)(u_temp & std::numeric_limits<uint8_t>::max());
break;
}
case kInt16Type: {
i_out = (int16_t)(u_temp & std::numeric_limits<uint16_t>::max());
break;
}
case kInt32Type: {
i_out = (int32_t)(u_temp & std::numeric_limits<uint32_t>::max());
break;
}
case kInt64Type: {
i_out = (int64_t)(u_temp & std::numeric_limits<uint64_t>::max());
break;
}
default: {
i_out = 0;
}
}
if (!dst_i_type) {
return i_out;
}
if (i_out >= static_cast<int64_t>(std::numeric_limits<int8_t>::min()) &&
i_out <= static_cast<int64_t>(std::numeric_limits<int8_t>::max())) {
*dst_i_type = kInt8Type;
} else if (i_out >= static_cast<int64_t>(std::numeric_limits<int16_t>::min()) &&
i_out <= static_cast<int64_t>(std::numeric_limits<int16_t>::max())) {
*dst_i_type = kInt16Type;
} else if (i_out >= static_cast<int64_t>(std::numeric_limits<int32_t>::min()) &&
i_out <= static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
*dst_i_type = kInt32Type;
} else {
*dst_i_type = kInt64Type;
}
return i_out;
}
} // namespace mindrecord
} // namespace mindspore
...@@ -201,9 +201,9 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade ...@@ -201,9 +201,9 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade
json header; json header;
header = ret.second; header = ret.second;
header["shard_addresses"] = realAddresses; header["shard_addresses"] = realAddresses;
if (header["version"] != version_) { if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) {
MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump()
<< ", lib version is: " << version_; << ", lib version is: " << kVersion;
thread_status = true; thread_status = true;
return; return;
} }
...@@ -339,7 +339,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() { ...@@ -339,7 +339,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
s += "\"shard_addresses\":" + address + ","; s += "\"shard_addresses\":" + address + ",";
s += "\"shard_id\":" + std::to_string(shardId) + ","; s += "\"shard_id\":" + std::to_string(shardId) + ",";
s += "\"statistics\":" + stats + ","; s += "\"statistics\":" + stats + ",";
s += "\"version\":\"" + version_ + "\""; s += "\"version\":\"" + std::string(kVersion) + "\"";
s += "}"; s += "}";
header.emplace_back(s); header.emplace_back(s);
} }
......
...@@ -97,16 +97,13 @@ def populate_data(raw, blob, columns, blob_fields, schema): ...@@ -97,16 +97,13 @@ def populate_data(raw, blob, columns, blob_fields, schema):
if not blob_fields: if not blob_fields:
return raw return raw
# Get the order preserving sequence of columns in blob loaded_columns = []
ordered_columns = []
if columns: if columns:
for blob_field in blob_fields: for column in columns:
if blob_field in columns: if column in blob_fields:
ordered_columns.append(blob_field) loaded_columns.append(column)
else: else:
ordered_columns = blob_fields loaded_columns = blob_fields
blob_bytes = bytes(blob)
def _render_raw(field, blob_data): def _render_raw(field, blob_data):
data_type = schema[field]['type'] data_type = schema[field]['type']
...@@ -119,24 +116,6 @@ def populate_data(raw, blob, columns, blob_fields, schema): ...@@ -119,24 +116,6 @@ def populate_data(raw, blob, columns, blob_fields, schema):
else: else:
raw[field] = blob_data raw[field] = blob_data
if len(blob_fields) == 1: for i, blob_field in enumerate(loaded_columns):
if len(ordered_columns) == 1: _render_raw(blob_field, bytes(blob[i]))
_render_raw(blob_fields[0], blob_bytes)
return raw
return raw
def _int_from_bytes(xbytes: bytes) -> int:
return int.from_bytes(xbytes, 'big')
def _blob_at_position(pos):
start = 0
for _ in range(pos):
n_bytes = _int_from_bytes(blob_bytes[start : start + 8])
start += 8 + n_bytes
n_bytes = _int_from_bytes(blob_bytes[start : start + 8])
start += 8
return blob_bytes[start : start + n_bytes]
for i, blob_field in enumerate(ordered_columns):
_render_raw(blob_field, _blob_at_position(i))
return raw return raw
...@@ -25,8 +25,24 @@ from mindspore.mindrecord import SUCCESS ...@@ -25,8 +25,24 @@ from mindspore.mindrecord import SUCCESS
CIFAR100_DIR = "../data/mindrecord/testCifar100Data" CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
MINDRECORD_FILE = "./cifar100.mindrecord" MINDRECORD_FILE = "./cifar100.mindrecord"
@pytest.fixture
def test_cifar100_to_mindrecord_without_index_fields(): def fixture_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
remove_file(MINDRECORD_FILE)
yield "yield_fixture_data"
remove_file(MINDRECORD_FILE)
def test_cifar100_to_mindrecord_without_index_fields(fixture_file):
"""test transform cifar100 dataset to mindrecord without index fields.""" """test transform cifar100 dataset to mindrecord without index fields."""
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
ret = cifar100_transformer.transform() ret = cifar100_transformer.transform()
...@@ -34,25 +50,14 @@ def test_cifar100_to_mindrecord_without_index_fields(): ...@@ -34,25 +50,14 @@ def test_cifar100_to_mindrecord_without_index_fields():
assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test") assert os.path.exists(MINDRECORD_FILE + "_test")
read() read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def test_cifar100_to_mindrecord(): def test_cifar100_to_mindrecord(fixture_file):
"""test transform cifar100 dataset to mindrecord.""" """test transform cifar100 dataset to mindrecord."""
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
cifar100_transformer.transform(['fine_label', 'coarse_label']) cifar100_transformer.transform(['fine_label', 'coarse_label'])
assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test") assert os.path.exists(MINDRECORD_FILE + "_test")
read() read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def read(): def read():
...@@ -77,8 +82,7 @@ def read(): ...@@ -77,8 +82,7 @@ def read():
assert count == 4 assert count == 4
reader.close() reader.close()
def test_cifar100_to_mindrecord_illegal_file_name(fixture_file):
def test_cifar100_to_mindrecord_illegal_file_name():
""" """
test transform cifar100 dataset to mindrecord test transform cifar100 dataset to mindrecord
when file name contains illegal character. when file name contains illegal character.
...@@ -88,8 +92,7 @@ def test_cifar100_to_mindrecord_illegal_file_name(): ...@@ -88,8 +92,7 @@ def test_cifar100_to_mindrecord_illegal_file_name():
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename) cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
cifar100_transformer.transform() cifar100_transformer.transform()
def test_cifar100_to_mindrecord_filename_start_with_space(fixture_file):
def test_cifar100_to_mindrecord_filename_start_with_space():
""" """
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when file name starts with space. when file name starts with space.
...@@ -100,8 +103,7 @@ def test_cifar100_to_mindrecord_filename_start_with_space(): ...@@ -100,8 +103,7 @@ def test_cifar100_to_mindrecord_filename_start_with_space():
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename) cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
cifar100_transformer.transform() cifar100_transformer.transform()
def test_cifar100_to_mindrecord_filename_contain_space(fixture_file):
def test_cifar100_to_mindrecord_filename_contain_space():
""" """
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when file name contains space. when file name contains space.
...@@ -111,14 +113,8 @@ def test_cifar100_to_mindrecord_filename_contain_space(): ...@@ -111,14 +113,8 @@ def test_cifar100_to_mindrecord_filename_contain_space():
cifar100_transformer.transform() cifar100_transformer.transform()
assert os.path.exists(filename) assert os.path.exists(filename)
assert os.path.exists(filename + "_test") assert os.path.exists(filename + "_test")
os.remove("{}".format(filename))
os.remove("{}.db".format(filename))
os.remove("{}".format(filename + "_test"))
os.remove("{}.db".format(filename + "_test"))
def test_cifar100_to_mindrecord_directory(fixture_file):
def test_cifar100_to_mindrecord_directory():
""" """
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when destination path is directory. when destination path is directory.
...@@ -129,8 +125,7 @@ def test_cifar100_to_mindrecord_directory(): ...@@ -129,8 +125,7 @@ def test_cifar100_to_mindrecord_directory():
CIFAR100_DIR) CIFAR100_DIR)
cifar100_transformer.transform() cifar100_transformer.transform()
def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file):
def test_cifar100_to_mindrecord_filename_equals_cifar100():
""" """
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when destination path equals source path. when destination path equals source path.
......
...@@ -24,36 +24,60 @@ from mindspore.mindrecord import MRMOpenError, SUCCESS ...@@ -24,36 +24,60 @@ from mindspore.mindrecord import MRMOpenError, SUCCESS
CIFAR10_DIR = "../data/mindrecord/testCifar10Data" CIFAR10_DIR = "../data/mindrecord/testCifar10Data"
MINDRECORD_FILE = "./cifar10.mindrecord" MINDRECORD_FILE = "./cifar10.mindrecord"
@pytest.fixture
def test_cifar10_to_mindrecord_without_index_fields(): def fixture_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
remove_file(MINDRECORD_FILE)
yield "yield_fixture_data"
remove_file(MINDRECORD_FILE)
@pytest.fixture
def fixture_space_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
x = "./yes ok"
remove_file(x)
yield "yield_fixture_data"
remove_file(x)
def test_cifar10_to_mindrecord_without_index_fields(fixture_file):
"""test transform cifar10 dataset to mindrecord without index fields.""" """test transform cifar10 dataset to mindrecord without index fields."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
cifar10_transformer.transform() cifar10_transformer.transform()
assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test") assert os.path.exists(MINDRECORD_FILE + "_test")
read() read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def test_cifar10_to_mindrecord(): def test_cifar10_to_mindrecord(fixture_file):
"""test transform cifar10 dataset to mindrecord.""" """test transform cifar10 dataset to mindrecord."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
cifar10_transformer.transform(['label']) cifar10_transformer.transform(['label'])
assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test") assert os.path.exists(MINDRECORD_FILE + "_test")
read() read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def test_cifar10_to_mindrecord_with_return(): def test_cifar10_to_mindrecord_with_return(fixture_file):
"""test transform cifar10 dataset to mindrecord.""" """test transform cifar10 dataset to mindrecord."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE) cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
ret = cifar10_transformer.transform(['label']) ret = cifar10_transformer.transform(['label'])
...@@ -61,11 +85,6 @@ def test_cifar10_to_mindrecord_with_return(): ...@@ -61,11 +85,6 @@ def test_cifar10_to_mindrecord_with_return():
assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test") assert os.path.exists(MINDRECORD_FILE + "_test")
read() read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def read(): def read():
...@@ -90,8 +109,7 @@ def read(): ...@@ -90,8 +109,7 @@ def read():
assert count == 4 assert count == 4
reader.close() reader.close()
def test_cifar10_to_mindrecord_illegal_file_name(fixture_file):
def test_cifar10_to_mindrecord_illegal_file_name():
""" """
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when file name contains illegal character. when file name contains illegal character.
...@@ -101,8 +119,7 @@ def test_cifar10_to_mindrecord_illegal_file_name(): ...@@ -101,8 +119,7 @@ def test_cifar10_to_mindrecord_illegal_file_name():
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename) cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
cifar10_transformer.transform() cifar10_transformer.transform()
def test_cifar10_to_mindrecord_filename_start_with_space(fixture_file):
def test_cifar10_to_mindrecord_filename_start_with_space():
""" """
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when file name starts with space. when file name starts with space.
...@@ -113,8 +130,7 @@ def test_cifar10_to_mindrecord_filename_start_with_space(): ...@@ -113,8 +130,7 @@ def test_cifar10_to_mindrecord_filename_start_with_space():
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename) cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
cifar10_transformer.transform() cifar10_transformer.transform()
def test_cifar10_to_mindrecord_filename_contain_space(fixture_space_file):
def test_cifar10_to_mindrecord_filename_contain_space():
""" """
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when file name contains space. when file name contains space.
...@@ -124,14 +140,8 @@ def test_cifar10_to_mindrecord_filename_contain_space(): ...@@ -124,14 +140,8 @@ def test_cifar10_to_mindrecord_filename_contain_space():
cifar10_transformer.transform() cifar10_transformer.transform()
assert os.path.exists(filename) assert os.path.exists(filename)
assert os.path.exists(filename + "_test") assert os.path.exists(filename + "_test")
os.remove("{}".format(filename))
os.remove("{}.db".format(filename))
os.remove("{}".format(filename + "_test"))
os.remove("{}.db".format(filename + "_test"))
def test_cifar10_to_mindrecord_directory(): def test_cifar10_to_mindrecord_directory(fixture_file):
""" """
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when destination path is directory. when destination path is directory.
......
...@@ -25,6 +25,26 @@ IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images" ...@@ -25,6 +25,26 @@ IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images"
MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord" MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord"
PARTITION_NUMBER = 4 PARTITION_NUMBER = 4
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_one_file(x):
if os.path.exists(x):
os.remove(x)
def remove_file():
x = MINDRECORD_FILE
remove_one_file(x)
x = MINDRECORD_FILE + ".db"
remove_one_file(x)
for i in range(PARTITION_NUMBER):
x = MINDRECORD_FILE + str(i)
remove_one_file(x)
x = MINDRECORD_FILE + str(i) + ".db"
remove_one_file(x)
remove_file()
yield "yield_fixture_data"
remove_file()
def read(filename): def read(filename):
"""test file reade""" """test file reade"""
...@@ -38,8 +58,7 @@ def read(filename): ...@@ -38,8 +58,7 @@ def read(filename):
assert count == 20 assert count == 20
reader.close() reader.close()
def test_imagenet_to_mindrecord(fixture_file):
def test_imagenet_to_mindrecord():
"""test transform imagenet dataset to mindrecord.""" """test transform imagenet dataset to mindrecord."""
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR, imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR,
MINDRECORD_FILE, PARTITION_NUMBER) MINDRECORD_FILE, PARTITION_NUMBER)
...@@ -48,12 +67,8 @@ def test_imagenet_to_mindrecord(): ...@@ -48,12 +67,8 @@ def test_imagenet_to_mindrecord():
assert os.path.exists(MINDRECORD_FILE + str(i)) assert os.path.exists(MINDRECORD_FILE + str(i))
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db") assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
read(MINDRECORD_FILE + "0") read(MINDRECORD_FILE + "0")
for i in range(PARTITION_NUMBER):
os.remove(MINDRECORD_FILE + str(i))
os.remove(MINDRECORD_FILE + str(i) + ".db")
def test_imagenet_to_mindrecord_default_partition_number(): def test_imagenet_to_mindrecord_default_partition_number(fixture_file):
""" """
test transform imagenet dataset to mindrecord test transform imagenet dataset to mindrecord
when partition number is default. when partition number is default.
...@@ -64,11 +79,8 @@ def test_imagenet_to_mindrecord_default_partition_number(): ...@@ -64,11 +79,8 @@ def test_imagenet_to_mindrecord_default_partition_number():
assert os.path.exists(MINDRECORD_FILE) assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + ".db") assert os.path.exists(MINDRECORD_FILE + ".db")
read(MINDRECORD_FILE) read(MINDRECORD_FILE)
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
def test_imagenet_to_mindrecord_partition_number_0(fixture_file):
def test_imagenet_to_mindrecord_partition_number_0():
""" """
test transform imagenet dataset to mindrecord test transform imagenet dataset to mindrecord
when partition number is 0. when partition number is 0.
...@@ -79,8 +91,7 @@ def test_imagenet_to_mindrecord_partition_number_0(): ...@@ -79,8 +91,7 @@ def test_imagenet_to_mindrecord_partition_number_0():
MINDRECORD_FILE, 0) MINDRECORD_FILE, 0)
imagenet_transformer.transform() imagenet_transformer.transform()
def test_imagenet_to_mindrecord_partition_number_none(fixture_file):
def test_imagenet_to_mindrecord_partition_number_none():
""" """
test transform imagenet dataset to mindrecord test transform imagenet dataset to mindrecord
when partition number is none. when partition number is none.
...@@ -92,8 +103,7 @@ def test_imagenet_to_mindrecord_partition_number_none(): ...@@ -92,8 +103,7 @@ def test_imagenet_to_mindrecord_partition_number_none():
MINDRECORD_FILE, None) MINDRECORD_FILE, None)
imagenet_transformer.transform() imagenet_transformer.transform()
def test_imagenet_to_mindrecord_illegal_filename(fixture_file):
def test_imagenet_to_mindrecord_illegal_filename():
""" """
test transform imagenet dataset to mindrecord test transform imagenet dataset to mindrecord
when file name contains illegal character. when file name contains illegal character.
......
...@@ -26,6 +26,34 @@ CV_FILE_NAME = "./imagenet.mindrecord" ...@@ -26,6 +26,34 @@ CV_FILE_NAME = "./imagenet.mindrecord"
NLP_FILE_NAME = "./aclImdb.mindrecord" NLP_FILE_NAME = "./aclImdb.mindrecord"
FILES_NUM = 4 FILES_NUM = 4
def remove_one_file(x):
if os.path.exists(x):
os.remove(x)
def remove_file(file_name):
x = file_name
remove_one_file(x)
x = file_name + ".db"
remove_one_file(x)
for i in range(FILES_NUM):
x = file_name + str(i)
remove_one_file(x)
x = file_name + str(i) + ".db"
remove_one_file(x)
@pytest.fixture
def fixture_cv_file():
"""add/remove file"""
remove_file(CV_FILE_NAME)
yield "yield_fixture_data"
remove_file(CV_FILE_NAME)
@pytest.fixture
def fixture_nlp_file():
"""add/remove file"""
remove_file(NLP_FILE_NAME)
yield "yield_fixture_data"
remove_file(NLP_FILE_NAME)
def test_cv_file_writer_shard_num_none(): def test_cv_file_writer_shard_num_none():
"""test cv file writer when shard num is None.""" """test cv file writer when shard num is None."""
...@@ -83,8 +111,7 @@ def test_lack_partition_and_db(): ...@@ -83,8 +111,7 @@ def test_lack_partition_and_db():
'error_msg: MindRecord File could not open successfully.' \ 'error_msg: MindRecord File could not open successfully.' \
in str(err.value) in str(err.value)
def test_lack_db(fixture_cv_file):
def test_lack_db():
"""test file reader when db file does not exist.""" """test file reader when db file does not exist."""
create_cv_mindrecord(1) create_cv_mindrecord(1)
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
...@@ -94,10 +121,8 @@ def test_lack_db(): ...@@ -94,10 +121,8 @@ def test_lack_db():
assert '[MRMOpenError]: error_code: 1347690596, ' \ assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \ 'error_msg: MindRecord File could not open successfully.' \
in str(err.value) in str(err.value)
os.remove(CV_FILE_NAME)
def test_lack_some_partition_and_db(fixture_cv_file):
def test_lack_some_partition_and_db():
"""test file reader when some partition and db do not exist.""" """test file reader when some partition and db do not exist."""
create_cv_mindrecord(4) create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...@@ -110,16 +135,8 @@ def test_lack_some_partition_and_db(): ...@@ -110,16 +135,8 @@ def test_lack_some_partition_and_db():
assert '[MRMOpenError]: error_code: 1347690596, ' \ assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \ 'error_msg: MindRecord File could not open successfully.' \
in str(err.value) in str(err.value)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_lack_some_partition_first(): def test_lack_some_partition_first(fixture_cv_file):
"""test file reader when first partition does not exist.""" """test file reader when first partition does not exist."""
create_cv_mindrecord(4) create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...@@ -131,14 +148,8 @@ def test_lack_some_partition_first(): ...@@ -131,14 +148,8 @@ def test_lack_some_partition_first():
assert '[MRMOpenError]: error_code: 1347690596, ' \ assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \ 'error_msg: MindRecord File could not open successfully.' \
in str(err.value) in str(err.value)
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_lack_some_partition_middle(): def test_lack_some_partition_middle(fixture_cv_file):
"""test file reader when some partition does not exist.""" """test file reader when some partition does not exist."""
create_cv_mindrecord(4) create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...@@ -150,14 +161,8 @@ def test_lack_some_partition_middle(): ...@@ -150,14 +161,8 @@ def test_lack_some_partition_middle():
assert '[MRMOpenError]: error_code: 1347690596, ' \ assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \ 'error_msg: MindRecord File could not open successfully.' \
in str(err.value) in str(err.value)
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_lack_some_partition_last(fixture_cv_file):
def test_lack_some_partition_last():
"""test file reader when last partition does not exist.""" """test file reader when last partition does not exist."""
create_cv_mindrecord(4) create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...@@ -169,14 +174,8 @@ def test_lack_some_partition_last(): ...@@ -169,14 +174,8 @@ def test_lack_some_partition_last():
assert '[MRMOpenError]: error_code: 1347690596, ' \ assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \ 'error_msg: MindRecord File could not open successfully.' \
in str(err.value) in str(err.value)
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_mindpage_lack_some_partition(): def test_mindpage_lack_some_partition(fixture_cv_file):
"""test page reader when some partition does not exist.""" """test page reader when some partition does not exist."""
create_cv_mindrecord(4) create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...@@ -187,14 +186,8 @@ def test_mindpage_lack_some_partition(): ...@@ -187,14 +186,8 @@ def test_mindpage_lack_some_partition():
assert '[MRMOpenError]: error_code: 1347690596, ' \ assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \ 'error_msg: MindRecord File could not open successfully.' \
in str(err.value) in str(err.value)
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_lack_some_db(): def test_lack_some_db(fixture_cv_file):
"""test file reader when some db does not exist.""" """test file reader when some db does not exist."""
create_cv_mindrecord(4) create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
...@@ -206,11 +199,6 @@ def test_lack_some_db(): ...@@ -206,11 +199,6 @@ def test_lack_some_db():
assert '[MRMOpenError]: error_code: 1347690596, ' \ assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \ 'error_msg: MindRecord File could not open successfully.' \
in str(err.value) in str(err.value)
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_invalid_mindrecord(): def test_invalid_mindrecord():
...@@ -225,8 +213,7 @@ def test_invalid_mindrecord(): ...@@ -225,8 +213,7 @@ def test_invalid_mindrecord():
in str(err.value) in str(err.value)
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)
def test_invalid_db(fixture_cv_file):
def test_invalid_db():
"""test file reader when the content of db is illegal.""" """test file reader when the content of db is illegal."""
create_cv_mindrecord(1) create_cv_mindrecord(1)
os.remove("imagenet.mindrecord.db") os.remove("imagenet.mindrecord.db")
...@@ -237,11 +224,8 @@ def test_invalid_db(): ...@@ -237,11 +224,8 @@ def test_invalid_db():
assert '[MRMOpenError]: error_code: 1347690596, ' \ assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \ 'error_msg: MindRecord File could not open successfully.' \
in str(err.value) in str(err.value)
os.remove("imagenet.mindrecord")
os.remove("imagenet.mindrecord.db")
def test_overwrite_invalid_mindrecord(): def test_overwrite_invalid_mindrecord(fixture_cv_file):
"""test file writer when overwrite invalid mindreocrd file.""" """test file writer when overwrite invalid mindreocrd file."""
with open(CV_FILE_NAME, 'w') as f: with open(CV_FILE_NAME, 'w') as f:
f.write('just for test') f.write('just for test')
...@@ -250,10 +234,8 @@ def test_overwrite_invalid_mindrecord(): ...@@ -250,10 +234,8 @@ def test_overwrite_invalid_mindrecord():
assert '[MRMOpenError]: error_code: 1347690596, ' \ assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \ 'error_msg: MindRecord File could not open successfully.' \
in str(err.value) in str(err.value)
os.remove(CV_FILE_NAME)
def test_overwrite_invalid_db(): def test_overwrite_invalid_db(fixture_cv_file):
"""test file writer when overwrite invalid db file.""" """test file writer when overwrite invalid db file."""
with open('imagenet.mindrecord.db', 'w') as f: with open('imagenet.mindrecord.db', 'w') as f:
f.write('just for test') f.write('just for test')
...@@ -261,11 +243,8 @@ def test_overwrite_invalid_db(): ...@@ -261,11 +243,8 @@ def test_overwrite_invalid_db():
create_cv_mindrecord(1) create_cv_mindrecord(1)
assert '[MRMGenerateIndexError]: error_code: 1347690612, ' \ assert '[MRMGenerateIndexError]: error_code: 1347690612, ' \
'error_msg: Failed to generate index.' in str(err.value) 'error_msg: Failed to generate index.' in str(err.value)
os.remove("imagenet.mindrecord")
os.remove("imagenet.mindrecord.db")
def test_read_after_close(): def test_read_after_close(fixture_cv_file):
"""test file reader when close read.""" """test file reader when close read."""
create_cv_mindrecord(1) create_cv_mindrecord(1)
reader = FileReader(CV_FILE_NAME) reader = FileReader(CV_FILE_NAME)
...@@ -275,11 +254,8 @@ def test_read_after_close(): ...@@ -275,11 +254,8 @@ def test_read_after_close():
count = count + 1 count = count + 1
logger.info("#item{}: {}".format(index, x)) logger.info("#item{}: {}".format(index, x))
assert count == 0 assert count == 0
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
def test_file_read_after_read(): def test_file_read_after_read(fixture_cv_file):
"""test file reader when finish read.""" """test file reader when finish read."""
create_cv_mindrecord(1) create_cv_mindrecord(1)
reader = FileReader(CV_FILE_NAME) reader = FileReader(CV_FILE_NAME)
...@@ -295,8 +271,6 @@ def test_file_read_after_read(): ...@@ -295,8 +271,6 @@ def test_file_read_after_read():
cnt = cnt + 1 cnt = cnt + 1
logger.info("#item{}: {}".format(index, x)) logger.info("#item{}: {}".format(index, x))
assert cnt == 0 assert cnt == 0
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
def test_cv_file_writer_shard_num_greater_than_1000(): def test_cv_file_writer_shard_num_greater_than_1000():
...@@ -312,8 +286,7 @@ def test_add_index_without_add_schema(): ...@@ -312,8 +286,7 @@ def test_add_index_without_add_schema():
fw.add_index(["label"]) fw.add_index(["label"])
assert 'Failed to get meta info' in str(err.value) assert 'Failed to get meta info' in str(err.value)
def test_mindpage_pageno_pagesize_not_int(fixture_cv_file):
def test_mindpage_pageno_pagesize_not_int():
"""test page reader when some partition does not exist.""" """test page reader when some partition does not exist."""
create_cv_mindrecord(4) create_cv_mindrecord(4)
reader = MindPage(CV_FILE_NAME + "0") reader = MindPage(CV_FILE_NAME + "0")
...@@ -342,14 +315,8 @@ def test_mindpage_pageno_pagesize_not_int(): ...@@ -342,14 +315,8 @@ def test_mindpage_pageno_pagesize_not_int():
with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."): with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
reader.read_at_page_by_id(99999, 0, 1) reader.read_at_page_by_id(99999, 0, 1)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
def test_mindpage_filename_not_exist(): def test_mindpage_filename_not_exist(fixture_cv_file):
"""test page reader when some partition does not exist.""" """test page reader when some partition does not exist."""
create_cv_mindrecord(4) create_cv_mindrecord(4)
reader = MindPage(CV_FILE_NAME + "0") reader = MindPage(CV_FILE_NAME + "0")
...@@ -374,6 +341,3 @@ def test_mindpage_filename_not_exist(): ...@@ -374,6 +341,3 @@ def test_mindpage_filename_not_exist():
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
"""test mnist to mindrecord tool""" """test mnist to mindrecord tool"""
import cv2 import cv2
import gzip import gzip
import pytest
import numpy as np import numpy as np
import os import os
...@@ -27,6 +28,34 @@ PARTITION_NUM = 4 ...@@ -27,6 +28,34 @@ PARTITION_NUM = 4
IMAGE_SIZE = 28 IMAGE_SIZE = 28
NUM_CHANNELS = 1 NUM_CHANNELS = 1
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_one_file(x):
if os.path.exists(x):
os.remove(x)
def remove_file():
x = "mnist_train.mindrecord"
remove_one_file(x)
x = "mnist_train.mindrecord.db"
remove_one_file(x)
x = "mnist_test.mindrecord"
remove_one_file(x)
x = "mnist_test.mindrecord.db"
remove_one_file(x)
for i in range(PARTITION_NUM):
x = "mnist_train.mindrecord" + str(i)
remove_one_file(x)
x = "mnist_train.mindrecord" + str(i) + ".db"
remove_one_file(x)
x = "mnist_test.mindrecord" + str(i)
remove_one_file(x)
x = "mnist_test.mindrecord" + str(i) + ".db"
remove_one_file(x)
remove_file()
yield "yield_fixture_data"
remove_file()
def read(train_name, test_name): def read(train_name, test_name):
"""test file reader""" """test file reader"""
...@@ -51,7 +80,7 @@ def read(train_name, test_name): ...@@ -51,7 +80,7 @@ def read(train_name, test_name):
reader.close() reader.close()
def test_mnist_to_mindrecord(): def test_mnist_to_mindrecord(fixture_file):
"""test transform mnist dataset to mindrecord.""" """test transform mnist dataset to mindrecord."""
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME) mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
mnist_transformer.transform() mnist_transformer.transform()
...@@ -60,13 +89,7 @@ def test_mnist_to_mindrecord(): ...@@ -60,13 +89,7 @@ def test_mnist_to_mindrecord():
read("mnist_train.mindrecord", "mnist_test.mindrecord") read("mnist_train.mindrecord", "mnist_test.mindrecord")
os.remove("{}".format("mnist_train.mindrecord")) def test_mnist_to_mindrecord_compare_data(fixture_file):
os.remove("{}.db".format("mnist_train.mindrecord"))
os.remove("{}".format("mnist_test.mindrecord"))
os.remove("{}.db".format("mnist_test.mindrecord"))
def test_mnist_to_mindrecord_compare_data():
"""test transform mnist dataset to mindrecord and compare data.""" """test transform mnist dataset to mindrecord and compare data."""
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME) mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
mnist_transformer.transform() mnist_transformer.transform()
...@@ -121,21 +144,10 @@ def test_mnist_to_mindrecord_compare_data(): ...@@ -121,21 +144,10 @@ def test_mnist_to_mindrecord_compare_data():
assert np.array(x['label']) == label assert np.array(x['label']) == label
reader.close() reader.close()
os.remove("{}".format("mnist_train.mindrecord")) def test_mnist_to_mindrecord_multi_partition(fixture_file):
os.remove("{}.db".format("mnist_train.mindrecord"))
os.remove("{}".format("mnist_test.mindrecord"))
os.remove("{}.db".format("mnist_test.mindrecord"))
def test_mnist_to_mindrecord_multi_partition():
"""test transform mnist dataset to multiple mindrecord files.""" """test transform mnist dataset to multiple mindrecord files."""
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM) mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
mnist_transformer.transform() mnist_transformer.transform()
read("mnist_train.mindrecord0", "mnist_test.mindrecord0") read("mnist_train.mindrecord0", "mnist_test.mindrecord0")
for i in range(PARTITION_NUM):
os.remove("{}".format("mnist_train.mindrecord" + str(i)))
os.remove("{}.db".format("mnist_train.mindrecord" + str(i)))
os.remove("{}".format("mnist_test.mindrecord" + str(i)))
os.remove("{}.db".format("mnist_test.mindrecord" + str(i)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册