提交 bb51bb88 编写于 作者: J jonwe 提交者: liyong

add compress in mindrecord

上级 2e3d55ed
......@@ -112,25 +112,26 @@ Status MindRecordOp::Init() {
data_schema_ = std::make_unique<DataSchema>();
std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->GetShardHeader()->GetSchemas();
// check whether schema exists, if so use the first one
CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found");
mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"];
std::vector<std::string> col_names = shard_reader_->get_shard_column()->GetColumnName();
CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found");
std::vector<mindrecord::ColumnDataType> col_data_types = shard_reader_->get_shard_column()->GeColumnDataType();
std::vector<std::vector<int64_t>> col_shapes = shard_reader_->get_shard_column()->GetColumnShape();
bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything
std::map<std::string, int32_t> colname_to_ind;
for (mindrecord::json::iterator it = mr_schema.begin(); it != mr_schema.end(); ++it) {
std::string colname = it.key(); // key of the json, column name
mindrecord::json it_value = it.value(); // value, which contains type info and may contain shape
for (uint32_t i = 0; i < col_names.size(); i++) {
std::string colname = col_names[i];
ColDescriptor col_desc;
TensorShape t_shape = TensorShape::CreateUnknownRankShape(); // shape of tensor, default unknown
std::string type_str = (it_value["type"] == "bytes" || it_value["type"] == "string") ? "uint8" : it_value["type"];
std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]];
DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"}
if (it_value["type"] == "bytes") { // rank = 1
if (col_data_types[i] == mindrecord::ColumnBytes || col_data_types[i] == mindrecord::ColumnString) { // rank = 1
col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1);
} else if (it_value.find("shape") != it_value.end()) {
std::vector<dsize_t> vec(it_value["shape"].size()); // temporary vector to hold shape
(void)std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin());
} else if (col_shapes[i].size() > 0) {
std::vector<dsize_t> vec(col_shapes[i].size()); // temporary vector to hold shape
(void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin());
t_shape = TensorShape(vec);
col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape);
} else { // unknown shape
......@@ -162,30 +163,7 @@ Status MindRecordOp::Init() {
num_rows_ = shard_reader_->GetNumRows();
// Compute how many buffers we would need to accomplish rowsPerBuffer
buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_;
RETURN_IF_NOT_OK(SetColumnsBlob());
return Status::OK();
}
Status MindRecordOp::SetColumnsBlob() {
columns_blob_ = shard_reader_->GetBlobFields().second;
// get the exactly blob fields by columns_to_load_
std::vector<std::string> columns_blob_exact;
for (auto &blob_field : columns_blob_) {
for (auto &column : columns_to_load_) {
if (column.compare(blob_field) == 0) {
columns_blob_exact.push_back(blob_field);
break;
}
}
}
columns_blob_index_ = std::vector<int32_t>(columns_to_load_.size(), -1);
int32_t iBlob = 0;
for (auto &blob_exact : columns_blob_exact) {
columns_blob_index_[column_name_id_map_[blob_exact]] = iBlob++;
}
return Status::OK();
}
......@@ -215,248 +193,18 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
}
}
template <typename T>
Status MindRecordOp::LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const {
TensorShape new_shape = TensorShape::CreateUnknownRankShape();
const unsigned char *data = nullptr;
std::unique_ptr<T[]> array_data;
std::string string_data;
const ColDescriptor &cur_column = data_schema_->column(i_col);
std::string column_name = columns_to_load_[i_col];
DataType type = cur_column.type();
// load blob column
if (columns_blob_index_[i_col] >= 0 && columns_blob.size() > 0) {
int32_t pos = columns_blob_.size() == 1 ? -1 : columns_blob_index_[i_col];
RETURN_IF_NOT_OK(LoadBlob(&new_shape, &data, columns_blob, pos, cur_column));
} else {
switch (type.value()) {
case DataType::DE_UINT8: {
// For strings (Assume DE_UINT8 is reserved for strings)
RETURN_IF_NOT_OK(LoadByte(&new_shape, &string_data, column_name, columns_json));
data = reinterpret_cast<const unsigned char *>(common::SafeCStr(string_data));
break;
}
case DataType::DE_FLOAT32: {
// For both float scalars and arrays
RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, false));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
case DataType::DE_FLOAT64: {
// For both double scalars and arrays
RETURN_IF_NOT_OK(LoadFloat(&new_shape, &array_data, column_name, columns_json, cur_column, true));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
default: {
// For both integers scalars and arrays
RETURN_IF_NOT_OK(LoadInt(&new_shape, &array_data, column_name, columns_json, cur_column));
data = reinterpret_cast<const unsigned char *>(array_data.get());
break;
}
}
}
// Create Tensor with given details
RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, cur_column.tensorImpl(), new_shape, type, data));
return Status::OK();
}
Status MindRecordOp::LoadBlob(TensorShape *new_shape, const unsigned char **data,
const std::vector<uint8_t> &columns_blob, const int32_t pos,
const ColDescriptor &column) {
const auto kColumnSize = column.type().SizeInBytes();
if (kColumnSize == 0) {
RETURN_STATUS_UNEXPECTED("column size is null");
}
if (pos == -1) {
if (column.hasShape()) {
*new_shape = TensorShape::CreateUnknownRankShape();
RETURN_IF_NOT_OK(
column.MaterializeTensorShape(static_cast<int32_t>(columns_blob.size() / kColumnSize), new_shape));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_blob.size() / kColumnSize)};
*new_shape = TensorShape(shapeDetails);
}
*data = reinterpret_cast<const uint8_t *>(&(columns_blob[0]));
return Status::OK();
}
auto uint64_from_bytes = [&](int64_t pos) {
uint64_t result = 0;
for (uint64_t n = 0; n < kInt64Len; n++) {
result = (result << 8) + columns_blob[pos + n];
}
return result;
};
uint64_t iStart = 0;
for (int32_t i = 0; i < pos; i++) {
uint64_t num_bytes = uint64_from_bytes(iStart);
iStart += kInt64Len + num_bytes;
}
uint64_t num_bytes = uint64_from_bytes(iStart);
iStart += kInt64Len;
if (column.hasShape()) {
*new_shape = TensorShape::CreateUnknownRankShape();
RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_bytes / kColumnSize), new_shape));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(num_bytes / kColumnSize)};
*new_shape = TensorShape(shapeDetails);
}
*data = reinterpret_cast<const uint8_t *>(&(columns_blob[iStart]));
return Status::OK();
}
template <typename T>
Status MindRecordOp::LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double) {
if (!columns_json[column_name].is_array()) {
T value = 0;
RETURN_IF_NOT_OK(GetFloat(&value, columns_json[column_name], use_double));
*new_shape = TensorShape::CreateScalar();
*array_data = std::make_unique<T[]>(1);
(*array_data)[0] = value;
} else {
if (column.hasShape()) {
*new_shape = TensorShape(column.shape());
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_json[column_name].size())};
*new_shape = TensorShape(shapeDetails);
}
int idx = 0;
*array_data = std::make_unique<T[]>(new_shape->NumOfElements());
for (auto &element : columns_json[column_name]) {
T value = 0;
RETURN_IF_NOT_OK(GetFloat(&value, element, use_double));
(*array_data)[idx++] = value;
}
}
return Status::OK();
}
template <typename T>
Status MindRecordOp::GetFloat(T *value, const mindrecord::json &data, bool use_double) {
if (data.is_number()) {
*value = data;
} else if (data.is_string()) {
try {
if (use_double) {
*value = data.get<double>();
} else {
*value = data.get<float>();
}
} catch (mindrecord::json::exception &e) {
RETURN_STATUS_UNEXPECTED("Conversion to float failed.");
}
} else {
RETURN_STATUS_UNEXPECTED("Conversion to float failed.");
}
return Status::OK();
}
template <typename T>
Status MindRecordOp::LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column) {
if (!columns_json[column_name].is_array()) {
T value = 0;
RETURN_IF_NOT_OK(GetInt(&value, columns_json[column_name]));
*new_shape = TensorShape::CreateScalar();
*array_data = std::make_unique<T[]>(1);
(*array_data)[0] = value;
} else {
if (column.hasShape()) {
*new_shape = TensorShape(column.shape());
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(columns_json[column_name].size())};
*new_shape = TensorShape(shapeDetails);
}
int idx = 0;
*array_data = std::make_unique<T[]>(new_shape->NumOfElements());
for (auto &element : columns_json[column_name]) {
T value = 0;
RETURN_IF_NOT_OK(GetInt(&value, element));
(*array_data)[idx++] = value;
}
}
return Status::OK();
}
template <typename T>
Status MindRecordOp::GetInt(T *value, const mindrecord::json &data) {
int64_t temp_value = 0;
bool less_than_zero = false;
if (data.is_number_integer()) {
const mindrecord::json json_zero = 0;
if (data < json_zero) less_than_zero = true;
temp_value = data;
} else if (data.is_string()) {
std::string string_value = data;
if (!string_value.empty() && string_value[0] == '-') {
try {
temp_value = std::stoll(string_value);
less_than_zero = true;
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument.");
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
}
} else {
try {
temp_value = static_cast<int64_t>(std::stoull(string_value));
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument.");
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
}
}
} else {
RETURN_STATUS_UNEXPECTED("Conversion to int failed.");
}
if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed. Out of range");
}
*value = static_cast<T>(temp_value);
return Status::OK();
}
Status MindRecordOp::LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name,
const mindrecord::json &columns_json) {
*string_data = columns_json[column_name];
std::vector<dsize_t> shape_details = {static_cast<dsize_t>(string_data->size())};
*new_shape = TensorShape(shape_details);
return Status::OK();
}
Status MindRecordOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::unique_ptr<IOBlock> io_block;
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
while (io_block != nullptr) {
if (io_block->eoe() == true) {
if (io_block->eoe()) {
RETURN_IF_NOT_OK(
out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
continue;
}
if (io_block->eof() == true) {
if (io_block->eof()) {
RETURN_IF_NOT_OK(
out_connector_->Add(worker_id, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block));
......@@ -521,19 +269,10 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
if (tupled_buffer.empty()) break;
}
for (const auto &tupled_row : tupled_buffer) {
std::vector<uint8_t> columnsBlob = std::get<0>(tupled_row);
std::vector<uint8_t> columns_blob = std::get<0>(tupled_row);
mindrecord::json columns_json = std::get<1>(tupled_row);
TensorRow tensor_row;
for (uint32_t j = 0; j < columns_to_load_.size(); ++j) {
std::shared_ptr<Tensor> tensor;
const ColDescriptor &cur_column = data_schema_->column(j);
DataType type = cur_column.type();
RETURN_IF_NOT_OK(SwitchLoadFeature(type, &tensor, j, columnsBlob, columns_json));
tensor_row.push_back(std::move(tensor));
}
RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json));
tensor_table->push_back(std::move(tensor_row));
}
}
......@@ -543,48 +282,46 @@ Status MindRecordOp::GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_bu
return Status::OK();
}
Status MindRecordOp::SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json) const {
switch (type.value()) {
case DataType::DE_BOOL: {
return LoadFeature<bool>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT8: {
return LoadFeature<int8_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT8: {
return LoadFeature<uint8_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT16: {
return LoadFeature<int16_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT16: {
return LoadFeature<uint16_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT32: {
return LoadFeature<int32_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT32: {
return LoadFeature<uint32_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_INT64: {
return LoadFeature<int64_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_UINT64: {
return LoadFeature<uint64_t>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_FLOAT32: {
return LoadFeature<float>(tensor, i_col, columns_blob, columns_json);
}
case DataType::DE_FLOAT64: {
return LoadFeature<double>(tensor, i_col, columns_blob, columns_json);
Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json) {
for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) {
auto column_name = columns_to_load_[i_col];
// Initialize column parameters
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0;
mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType;
uint64_t column_data_type_size = 1;
std::vector<int64_t> column_shape;
// Get column data
auto has_column = shard_reader_->get_shard_column()->GetColumnValueByName(
column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, &column_data_type, &column_data_type_size,
&column_shape);
if (has_column == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader.");
}
default: {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"mindrecord column list type does not match any known types");
std::shared_ptr<Tensor> tensor;
const ColDescriptor &column = data_schema_->column(i_col);
DataType type = column.type();
// Set shape
auto num_elements = n_bytes / column_data_type_size;
if (column.hasShape()) {
auto new_shape = TensorShape(column.shape());
RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_elements), &new_shape));
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data));
} else {
std::vector<dsize_t> shapeDetails = {static_cast<dsize_t>(num_elements)};
auto new_shape = TensorShape(shapeDetails);
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data));
}
tensor_row->push_back(std::move(tensor));
}
return Status::OK();
}
Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) {
......
......@@ -23,6 +23,7 @@
#include <queue>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
......@@ -31,6 +32,7 @@
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/queue.h"
#include "dataset/util/status.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_reader.h"
#include "mindrecord/include/common/shard_utils.h"
......@@ -193,8 +195,6 @@ class MindRecordOp : public ParallelOp {
Status Init();
Status SetColumnsBlob();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
......@@ -205,56 +205,11 @@ class MindRecordOp : public ParallelOp {
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
// Parses a single cell and puts the data into a tensor
// @param tensor - the tensor to put the parsed data in
// @param i_col - the id of column to parse
// @param tensor_row - the tensor row to put the parsed data in
// @param columns_blob - the blob data received from the reader
// @param columns_json - the data for fields received from the reader
template <typename T>
Status LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json) const;
Status SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const;
static Status LoadBlob(TensorShape *new_shape, const unsigned char **data, const std::vector<uint8_t> &columns_blob,
const int32_t pos, const ColDescriptor &column);
// Get shape and data (scalar or array) for tensor to be created (for floats and doubles)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
// @param use_double - boolean to choose between float32 and float64
template <typename T>
static Status LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double);
// Get shape and data (scalar or array) for tensor to be created (for integers)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
template <typename T>
static Status LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column);
static Status LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name,
const mindrecord::json &columns_json);
// Get a single float value from the given json
// @param value - the float to put the value in
// @param arrayData - the given json containing the float
// @param use_double - boolean to choose between float32 and float64
template <typename T>
static Status GetFloat(T *value, const mindrecord::json &data, bool use_double);
// Get a single integer value from the given json
// @param value - the integer to put the value in
// @param arrayData - the given json containing the integer
template <typename T>
static Status GetInt(T *value, const mindrecord::json &data);
Status LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json);
Status FetchBlockBuffer(const int32_t &buffer_id);
......
......@@ -91,8 +91,8 @@ void BindShardReader(const py::module *m) {
.def("launch", &ShardReader::Launch)
.def("get_header", &ShardReader::GetShardHeader)
.def("get_blob_fields", &ShardReader::GetBlobFields)
.def("get_next",
(std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy)
.def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) &
ShardReader::GetNextPy)
.def("finish", &ShardReader::Finish)
.def("close", &ShardReader::Close);
}
......
......@@ -65,6 +65,9 @@ const int kUnsignedInt4 = 4;
enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel };
const char kVersion[] = "3.0";
const std::vector<std::string> kSupportedVersion = {"2.0", kVersion};
enum ShardType {
kNLP = 0,
kCV = 1,
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_header.h"
namespace mindspore {
namespace mindrecord {
const uint64_t kUnsignedOne = 1;
const uint64_t kBitsOfByte = 8;
const uint64_t kDataTypeBits = 2;
const uint64_t kNumDataOfByte = 4;
const uint64_t kBytesOfColumnLen = 4;
const uint64_t kDataTypeBitMask = 3;
const uint64_t kDataTypes = 6;
enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type };
enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound };
enum ColumnDataType {
ColumnBytes = 0,
ColumnString = 1,
ColumnInt32 = 2,
ColumnInt64 = 3,
ColumnFloat32 = 4,
ColumnFloat64 = 5,
ColumnNoDataType = 6
};
// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"};
const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8};
const std::vector<std::string> ColumnDataTypeNameNormalized = {"uint8", "uint8", "int32",
"int64", "float32", "float64"};
const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {
{"bytes", ColumnBytes}, {"string", ColumnString}, {"int32", ColumnInt32},
{"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}};
class ShardColumn {
public:
explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true);
~ShardColumn() = default;
/// \brief get column value by column name
MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);
/// \brief compress blob
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob);
/// \brief check if blob compressed
bool CheckCompressBlob() const { return has_compress_blob_; }
uint64_t GetNumBlobColumn() const { return num_blob_column_; }
std::vector<std::string> GetColumnName() { return column_name_; }
std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; }
std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }
/// \brief get column value from blob
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *n_bytes);
private:
/// \brief get column value from json
MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
/// \brief get float value from json
template <typename T>
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);
/// \brief get integer value from json
template <typename T>
MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value);
/// \brief get column offset address and size from blob
MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx);
/// \brief check if column name is available
ColumnCategory CheckColumnName(const std::string &column_name);
/// \brief compress integer column
static vector<uint8_t> CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type);
/// \brief uncompress integer array column
template <typename T>
static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx);
/// \brief convert big-endian bytes to unsigned int
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param i_type integer type
/// \return unsigned int
static uint64_t BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &i_type);
/// \brief convert unsigned int to big-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static std::vector<uint8_t> UIntToBytesBig(uint64_t value, const IntegerType &i_type);
/// \brief convert unsigned int to little-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static std::vector<uint8_t> UIntToBytesLittle(uint64_t value, const IntegerType &i_type);
/// \brief convert unsigned int to little-endian bytes
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param src_i_type source integer typ0e
/// \param dst_i_type (output), destination integer type
/// \return integer
static int64_t BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr);
private:
std::vector<std::string> column_name_; // column name list
std::vector<ColumnDataType> column_data_type_; // column data type list
std::vector<std::vector<int64_t>> column_shape_; // column shape list
std::unordered_map<string, uint64_t> column_name_id_; // column name id map
std::vector<std::string> blob_column_; // blob column list
std::unordered_map<std::string, uint64_t> blob_column_id_; // blob column name id map
bool has_compress_blob_; // if has compress blob
uint64_t num_blob_column_; // number of blob columns
};
} // namespace mindrecord
} // namespace mindspore
#endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_
......@@ -118,8 +118,6 @@ class ShardHeader {
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
const string GetVersion() { return version_; }
std::vector<std::string> SerializeHeader();
MSRStatus PagesToFile(const std::string dump_file_name);
......@@ -175,7 +173,6 @@ class ShardHeader {
uint32_t shard_count_;
uint64_t header_size_;
uint64_t page_size_;
string version_ = "2.0";
std::shared_ptr<Index> index_;
std::vector<std::string> shard_addresses_;
......
......@@ -43,6 +43,7 @@
#include <vector>
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_index_generator.h"
#include "mindrecord/include/shard_operator.h"
......@@ -111,6 +112,10 @@ class ShardReader {
/// \return the metadata
std::shared_ptr<ShardHeader> GetShardHeader() const;
/// \brief aim to get columns context
/// \return the columns
std::shared_ptr<ShardColumn> get_shard_column() const;
/// \brief get the number of shards
/// \return # of shards
int GetShardCount() const;
......@@ -185,7 +190,7 @@ class ShardReader {
/// \brief return a batch, given that one is ready, python API
/// \return a batch of images and image data
std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> GetNextPy();
std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> GetNextPy();
/// \brief get blob filed list
/// \return blob field list
......@@ -295,16 +300,18 @@ class ShardReader {
/// \brief get number of classes
int64_t GetNumClasses(const std::string &category_field);
/// \brief get meta of header
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
/// \brief get exactly blob fields data by indices
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
std::vector<uint32_t> &ordered_selected_columns_index);
/// \brief extract uncompressed data based on column list
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);
protected:
uint64_t header_size_; // header size
uint64_t page_size_; // page size
int shard_count_; // number of shards
std::shared_ptr<ShardHeader> shard_header_; // shard header
std::shared_ptr<ShardColumn> shard_column_; // shard column
std::vector<sqlite3 *> database_paths_; // sqlite handle list
std::vector<string> file_paths_; // file paths
......
......@@ -36,6 +36,7 @@
#include <utility>
#include <vector>
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_header.h"
#include "mindrecord/include/shard_index.h"
......@@ -242,7 +243,8 @@ class ShardWriter {
std::vector<std::string> file_paths_; // file paths
std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles
std::shared_ptr<ShardHeader> shard_header_; // shard headers
std::shared_ptr<ShardHeader> shard_header_; // shard header
std::shared_ptr<ShardColumn> shard_column_; // shard columns
std::map<uint64_t, std::map<int, std::string>> err_mg_; // used for storing error raw_data info
......
......@@ -133,6 +133,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
shard_header_ = std::make_shared<ShardHeader>(sh);
header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize();
// version < 3.0
if (first_meta_data["version"] < kVersion) {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
} else {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
}
num_rows_ = 0;
auto row_group_summary = ReadRowGroupSummary();
for (const auto &rg : row_group_summary) {
......@@ -226,6 +232,8 @@ void ShardReader::Close() {
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }
std::shared_ptr<ShardColumn> ShardReader::get_shard_column() const { return shard_column_; }
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }
int ShardReader::GetNumRows() const { return num_rows_; }
......@@ -1059,36 +1067,6 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
return SUCCESS;
}
std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns(
std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) {
std::vector<uint8_t> exactly_blob_fields_bytes;
auto uint64_from_bytes = [&](int64_t pos) {
uint64_t result = 0;
for (uint64_t n = 0; n < kInt64Len; n++) {
result = (result << 8) + blob_fields_bytes[pos + n];
}
return result;
};
// get the exactly blob fields
uint32_t current_index = 0;
uint64_t current_offset = 0;
uint64_t data_len = uint64_from_bytes(current_offset);
while (current_offset < blob_fields_bytes.size()) {
if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(),
[&current_index](uint32_t &index) { return index == current_index; })) {
exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset,
blob_fields_bytes.begin() + current_offset + kInt64Len + data_len);
}
current_index++;
current_offset += kInt64Len + data_len;
data_len = uint64_from_bytes(current_offset);
}
return exactly_blob_fields_bytes;
}
TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
// All tasks are done
if (task_id >= static_cast<int>(tasks_.Size())) {
......@@ -1126,40 +1104,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
}
// extract the exactly blob bytes by selected columns
std::vector<uint8_t> images_with_exact_columns;
if (selected_columns_.size() == 0) {
images_with_exact_columns = images;
} else {
auto blob_fields = GetBlobFields();
std::vector<uint32_t> ordered_selected_columns_index;
uint32_t index = 0;
for (auto &blob_field : blob_fields.second) {
for (auto &field : selected_columns_) {
if (field.compare(blob_field) == 0) {
ordered_selected_columns_index.push_back(index);
break;
}
}
index++;
}
if (ordered_selected_columns_index.size() != 0) {
// extract the images
if (blob_fields.second.size() == 1) {
if (ordered_selected_columns_index.size() == 1) {
images_with_exact_columns = images;
}
} else {
images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index);
}
}
}
// Deliver batch data to output map
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task)));
batch.emplace_back(std::move(images), std::move(std::get<2>(task)));
return std::make_pair(SUCCESS, std::move(batch));
}
......@@ -1369,16 +1317,41 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNextById(con
return std::move(ret.second);
}
std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> ShardReader::GetNextPy() {
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardReader::UnCompressBlob(
const std::vector<uint8_t> &raw_blob_data) {
auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_;
auto blob_fields = GetBlobFields().second;
std::vector<std::vector<uint8_t>> blob_data;
for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) {
if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue;
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0;
auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << ".";
return {FAILED, std::vector<std::vector<uint8_t>>(blob_fields.size(), std::vector<uint8_t>())};
}
if (data == nullptr) {
data = reinterpret_cast<const unsigned char *>(data_ptr.get());
}
std::vector<uint8_t> column(data, data + (n_bytes / sizeof(unsigned char)));
blob_data.push_back(column);
}
return {SUCCESS, blob_data};
}
std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> ShardReader::GetNextPy() {
auto res = GetNext();
vector<std::tuple<std::vector<uint8_t>, pybind11::object>> jsonData;
std::transform(res.begin(), res.end(), std::back_inserter(jsonData),
[](const std::tuple<std::vector<uint8_t>, json> &item) {
vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> data;
std::transform(res.begin(), res.end(), std::back_inserter(data),
[this](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj));
auto ret = UnCompressBlob(std::get<0>(item));
return std::make_tuple(ret.second, std::move(obj));
});
return jsonData;
return data;
}
void ShardReader::Reset() {
......
......@@ -206,6 +206,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
MS_LOG(ERROR) << "Open file failed";
return FAILED;
}
shard_column_ = std::make_shared<ShardColumn>(shard_header_);
return SUCCESS;
}
......@@ -271,6 +272,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
shard_header_ = header_data;
shard_header_->SetHeaderSize(header_size_);
shard_header_->SetPageSize(page_size_);
shard_column_ = std::make_shared<ShardColumn>(shard_header_);
return SUCCESS;
}
......@@ -608,6 +610,14 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
MS_LOG(ERROR) << "IO error / there is no free disk to be used";
return FAILED;
}
// compress blob
if (shard_column_->CheckCompressBlob()) {
for (auto &blob : blob_data) {
blob = shard_column_->CompressBlob(blob);
}
}
// Add 4-bytes dummy blob data if no any blob fields
if (blob_data.size() == 0 && raw_data.size() > 0) {
blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindrecord/include/shard_column.h"
#include "common/utils.h"
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_error.h"
namespace mindspore {
namespace mindrecord {
ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
auto first_schema = shard_header->GetSchemas()[0];
auto schema = first_schema->GetSchema()["schema"];
bool has_integer_array = false;
for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
const std::string &column_name = it.key();
column_name_.push_back(column_name);
json it_value = it.value();
std::string str_type = it_value["type"];
column_data_type_.push_back(ColumnDataTypeMap.at(str_type));
if (it_value.find("shape") != it_value.end()) {
std::vector<int64_t> vec(it_value["shape"].size());
std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin());
column_shape_.push_back(vec);
if (str_type == "int32" || str_type == "int64") {
has_integer_array = true;
}
} else {
std::vector<int64_t> vec = {};
column_shape_.push_back(vec);
}
}
for (uint64_t i = 0; i < column_name_.size(); i++) {
column_name_id_[column_name_[i]] = i;
}
auto blob_fields = first_schema->GetBlobFields();
for (const auto &field : blob_fields) {
blob_column_.push_back(field);
}
for (uint64_t i = 0; i < blob_column_.size(); i++) {
blob_column_id_[blob_column_[i]] = i;
}
has_compress_blob_ = (compress_integer && has_integer_array);
num_blob_column_ = blob_column_.size();
}
MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape) {
// Skip if column not found
auto column_category = CheckColumnName(column_name);
if (column_category == ColumnNotFound) {
return FAILED;
}
// Get data type and size
auto column_id = column_name_id_[column_name];
*column_data_type = column_data_type_[column_id];
*column_data_type_size = ColumnDataTypeSize[*column_data_type];
*column_shape = column_shape_[column_id];
// Retrieve value from json
if (column_category == ColumnInRaw) {
if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) {
MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << ".";
return FAILED;
}
*data = reinterpret_cast<const unsigned char *>(data_ptr->get());
return SUCCESS;
}
// Retrieve value from blob
if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) {
MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << ".";
return FAILED;
}
if (*data == nullptr) {
*data = reinterpret_cast<const unsigned char *>(data_ptr->get());
}
return SUCCESS;
}
MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) {
auto column_id = column_name_id_[column_name];
auto column_data_type = column_data_type_[column_id];
// Initialize num bytes
*n_bytes = ColumnDataTypeSize[column_data_type];
auto json_column_value = columns_json[column_name];
switch (column_data_type) {
case ColumnFloat32: {
return GetFloat<float>(data_ptr, json_column_value, false);
}
case ColumnFloat64: {
return GetFloat<double>(data_ptr, json_column_value, true);
}
case ColumnInt32: {
return GetInt<int32_t>(data_ptr, json_column_value);
}
case ColumnInt64: {
return GetInt<int64_t>(data_ptr, json_column_value);
}
default: {
// Convert string to c_str
std::string tmp_string = json_column_value;
*n_bytes = tmp_string.size();
auto data = reinterpret_cast<const unsigned char *>(common::SafeCStr(tmp_string));
*data_ptr = std::make_unique<unsigned char[]>(*n_bytes);
for (uint32_t i = 0; i < *n_bytes; i++) {
(*data_ptr)[i] = *(data + i);
}
break;
}
}
return SUCCESS;
}
template <typename T>
MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
bool use_double) {
std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
if (!json_column_value.is_string() && !json_column_value.is_number()) {
MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
return FAILED;
}
if (json_column_value.is_number()) {
array_data[0] = json_column_value;
} else {
// Convert string to float
try {
if (use_double) {
array_data[0] = json_column_value.get<double>();
} else {
array_data[0] = json_column_value.get<float>();
}
} catch (json::exception &e) {
MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
return FAILED;
}
}
auto data = reinterpret_cast<const unsigned char *>(array_data.get());
*data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
for (uint32_t i = 0; i < sizeof(T); i++) {
(*data_ptr)[i] = *(data + i);
}
return SUCCESS;
}
template <typename T>
MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) {
std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
int64_t temp_value;
bool less_than_zero = false;
if (json_column_value.is_number_integer()) {
const json json_zero = 0;
if (json_column_value < json_zero) less_than_zero = true;
temp_value = json_column_value;
} else if (json_column_value.is_string()) {
std::string string_value = json_column_value;
if (!string_value.empty() && string_value[0] == '-') {
try {
temp_value = std::stoll(string_value);
less_than_zero = true;
} catch (std::invalid_argument &e) {
MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
return FAILED;
} catch (std::out_of_range &e) {
MS_LOG(ERROR) << "Conversion to int failed, out of range.";
return FAILED;
}
} else {
try {
temp_value = static_cast<int64_t>(std::stoull(string_value));
} catch (std::invalid_argument &e) {
MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
return FAILED;
} catch (std::out_of_range &e) {
MS_LOG(ERROR) << "Conversion to int failed, out of range.";
return FAILED;
}
}
} else {
MS_LOG(ERROR) << "Conversion to int failed.";
return FAILED;
}
if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
MS_LOG(ERROR) << "Conversion to int failed. Out of range";
return FAILED;
}
array_data[0] = static_cast<T>(temp_value);
auto data = reinterpret_cast<const unsigned char *>(array_data.get());
*data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
for (uint32_t i = 0; i < sizeof(T); i++) {
(*data_ptr)[i] = *(data + i);
}
return SUCCESS;
}
MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *n_bytes) {
uint64_t offset_address = 0;
auto column_id = column_name_id_[column_name];
if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) {
return FAILED;
}
auto column_data_type = column_data_type_[column_id];
if (has_compress_blob_ && column_data_type == ColumnInt32) {
if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
return FAILED;
}
} else if (has_compress_blob_ && column_data_type == ColumnInt64) {
if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
return FAILED;
}
} else {
*data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address]));
}
return SUCCESS;
}
ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
auto it_column = column_name_id_.find(column_name);
if (it_column == column_name_id_.end()) {
return ColumnNotFound;
}
auto it_blob = blob_column_id_.find(column_name);
return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob;
}
std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) {
// Skip if no compress columns
if (!CheckCompressBlob()) return blob;
std::vector<uint8_t> dst_blob;
uint64_t i_src = 0;
for (int64_t i = 0; i < num_blob_column_; i++) {
// Get column data type
auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]];
auto int_type = src_data_type == ColumnInt32 ? kInt32Type : kInt64Type;
// Compress and return is blob has 1 column only
if (num_blob_column_ == 1) {
return CompressInt(blob, int_type);
}
// Just copy and continue if column dat type is not int32/int64
uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type);
if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) {
dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes);
i_src += kInt64Len + num_bytes;
continue;
}
// Get column slice in source blob
std::vector<uint8_t> blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes);
// Compress column
auto dst_blob_slice = CompressInt(blob_slice, int_type);
// Get new column size
auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type);
// Append new colmn size
dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end());
// Append new colmn data
dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end());
i_src += kInt64Len + num_bytes;
}
MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << ".";
return dst_blob;
}
vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
uint64_t i_size = kUnsignedOne << int_type;
// Get number of elements
uint64_t src_n_int = src_bytes.size() / i_size;
// Calculate bitmap size (bytes)
uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte;
// Initilize destination blob, more space than needed, will be resized
vector<uint8_t> dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0);
// Write number of elements to destination blob
vector<uint8_t> size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type);
for (uint64_t n = 0; n < kBytesOfColumnLen; n++) {
dst_bytes[n] = size_by_bytes[n];
}
// Write compressed int
uint64_t i_dst = kBytesOfColumnLen + bitmap_size;
for (uint64_t i = 0; i < src_n_int; i++) {
// Initialize destination data type
IntegerType dst_int_type = kInt8Type;
// Shift to next int position
uint64_t pos = i * (kUnsignedOne << int_type);
// Narrow down this int
int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);
// Write this int to destination blob
uint64_t u_n = *reinterpret_cast<uint64_t *>(&i_n);
auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type);
for (uint64_t j = 0; j < (kUnsignedOne << dst_int_type); j++) {
dst_bytes[i_dst++] = temp_bytes[j];
}
// Update date type in bit map
dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |=
(dst_int_type << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte))));
}
// Resize destination blob
dst_bytes.resize(i_dst);
MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << ".";
return dst_bytes;
}
MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx) {
if (num_blob_column_ == 1) {
*num_bytes = columns_blob.size();
*shift_idx = 0;
return SUCCESS;
}
auto blob_id = blob_column_id_[column_name_[column_id]];
for (int32_t i = 0; i < blob_id; i++) {
*shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
}
*num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
(*shift_idx) += kInt64Len;
return SUCCESS;
}
template <typename T>
MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes,
uint64_t shift_idx) {
auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type);
*num_bytes = sizeof(T) * num_elements;
// Parse integer array
uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte;
auto array_data = std::make_unique<T[]>(num_elements);
for (uint64_t i = 0; i < num_elements; i++) {
uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte];
uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask;
auto mr_int_type = static_cast<IntegerType>(i_type);
int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type);
i_source += (kUnsignedOne << i_type);
array_data[i] = static_cast<T>(i64);
}
auto data = reinterpret_cast<const unsigned char *>(array_data.get());
*data_ptr = std::make_unique<unsigned char[]>(*num_bytes);
memcpy(data_ptr->get(), data, *num_bytes);
return SUCCESS;
}
uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &i_type) {
uint64_t result = 0;
for (uint64_t i = 0; i < (kUnsignedOne << i_type); i++) {
result = (result << kBitsOfByte) + bytes_array[pos + i];
}
return result;
}
std::vector<uint8_t> ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) {
uint64_t n_bytes = kUnsignedOne << i_type;
std::vector<uint8_t> result(n_bytes, 0);
for (uint64_t i = 0; i < n_bytes; i++) {
result[n_bytes - 1 - i] = value & std::numeric_limits<uint8_t>::max();
value >>= kBitsOfByte;
}
return result;
}
std::vector<uint8_t> ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) {
uint64_t n_bytes = kUnsignedOne << i_type;
std::vector<uint8_t> result(n_bytes, 0);
for (uint64_t i = 0; i < n_bytes; i++) {
result[i] = value & std::numeric_limits<uint8_t>::max();
value >>= kBitsOfByte;
}
return result;
}
int64_t ShardColumn::BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &src_i_type, IntegerType *dst_i_type) {
uint64_t u_temp = 0;
for (uint64_t i = 0; i < (kUnsignedOne << src_i_type); i++) {
u_temp = (u_temp << kBitsOfByte) + bytes_array[pos + (kUnsignedOne << src_i_type) - kUnsignedOne - i];
}
int64_t i_out;
switch (src_i_type) {
case kInt8Type: {
i_out = (int8_t)(u_temp & std::numeric_limits<uint8_t>::max());
break;
}
case kInt16Type: {
i_out = (int16_t)(u_temp & std::numeric_limits<uint16_t>::max());
break;
}
case kInt32Type: {
i_out = (int32_t)(u_temp & std::numeric_limits<uint32_t>::max());
break;
}
case kInt64Type: {
i_out = (int64_t)(u_temp & std::numeric_limits<uint64_t>::max());
break;
}
default: {
i_out = 0;
}
}
if (!dst_i_type) {
return i_out;
}
if (i_out >= static_cast<int64_t>(std::numeric_limits<int8_t>::min()) &&
i_out <= static_cast<int64_t>(std::numeric_limits<int8_t>::max())) {
*dst_i_type = kInt8Type;
} else if (i_out >= static_cast<int64_t>(std::numeric_limits<int16_t>::min()) &&
i_out <= static_cast<int64_t>(std::numeric_limits<int16_t>::max())) {
*dst_i_type = kInt16Type;
} else if (i_out >= static_cast<int64_t>(std::numeric_limits<int32_t>::min()) &&
i_out <= static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
*dst_i_type = kInt32Type;
} else {
*dst_i_type = kInt64Type;
}
return i_out;
}
} // namespace mindrecord
} // namespace mindspore
......@@ -201,9 +201,9 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade
json header;
header = ret.second;
header["shard_addresses"] = realAddresses;
if (header["version"] != version_) {
if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) {
MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump()
<< ", lib version is: " << version_;
<< ", lib version is: " << kVersion;
thread_status = true;
return;
}
......@@ -339,7 +339,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
s += "\"shard_addresses\":" + address + ",";
s += "\"shard_id\":" + std::to_string(shardId) + ",";
s += "\"statistics\":" + stats + ",";
s += "\"version\":\"" + version_ + "\"";
s += "\"version\":\"" + std::string(kVersion) + "\"";
s += "}";
header.emplace_back(s);
}
......
......@@ -97,16 +97,13 @@ def populate_data(raw, blob, columns, blob_fields, schema):
if not blob_fields:
return raw
# Get the order preserving sequence of columns in blob
ordered_columns = []
loaded_columns = []
if columns:
for blob_field in blob_fields:
if blob_field in columns:
ordered_columns.append(blob_field)
for column in columns:
if column in blob_fields:
loaded_columns.append(column)
else:
ordered_columns = blob_fields
blob_bytes = bytes(blob)
loaded_columns = blob_fields
def _render_raw(field, blob_data):
data_type = schema[field]['type']
......@@ -119,24 +116,6 @@ def populate_data(raw, blob, columns, blob_fields, schema):
else:
raw[field] = blob_data
if len(blob_fields) == 1:
if len(ordered_columns) == 1:
_render_raw(blob_fields[0], blob_bytes)
return raw
return raw
def _int_from_bytes(xbytes: bytes) -> int:
return int.from_bytes(xbytes, 'big')
def _blob_at_position(pos):
start = 0
for _ in range(pos):
n_bytes = _int_from_bytes(blob_bytes[start : start + 8])
start += 8 + n_bytes
n_bytes = _int_from_bytes(blob_bytes[start : start + 8])
start += 8
return blob_bytes[start : start + n_bytes]
for i, blob_field in enumerate(ordered_columns):
_render_raw(blob_field, _blob_at_position(i))
for i, blob_field in enumerate(loaded_columns):
_render_raw(blob_field, bytes(blob[i]))
return raw
......@@ -25,8 +25,24 @@ from mindspore.mindrecord import SUCCESS
CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
MINDRECORD_FILE = "./cifar100.mindrecord"
def test_cifar100_to_mindrecord_without_index_fields():
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
remove_file(MINDRECORD_FILE)
yield "yield_fixture_data"
remove_file(MINDRECORD_FILE)
def test_cifar100_to_mindrecord_without_index_fields(fixture_file):
"""test transform cifar100 dataset to mindrecord without index fields."""
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
ret = cifar100_transformer.transform()
......@@ -34,25 +50,14 @@ def test_cifar100_to_mindrecord_without_index_fields():
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def test_cifar100_to_mindrecord():
def test_cifar100_to_mindrecord(fixture_file):
"""test transform cifar100 dataset to mindrecord."""
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
cifar100_transformer.transform(['fine_label', 'coarse_label'])
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def read():
......@@ -77,8 +82,7 @@ def read():
assert count == 4
reader.close()
def test_cifar100_to_mindrecord_illegal_file_name():
def test_cifar100_to_mindrecord_illegal_file_name(fixture_file):
"""
test transform cifar100 dataset to mindrecord
when file name contains illegal character.
......@@ -88,8 +92,7 @@ def test_cifar100_to_mindrecord_illegal_file_name():
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
cifar100_transformer.transform()
def test_cifar100_to_mindrecord_filename_start_with_space():
def test_cifar100_to_mindrecord_filename_start_with_space(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when file name starts with space.
......@@ -100,8 +103,7 @@ def test_cifar100_to_mindrecord_filename_start_with_space():
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
cifar100_transformer.transform()
def test_cifar100_to_mindrecord_filename_contain_space():
def test_cifar100_to_mindrecord_filename_contain_space(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when file name contains space.
......@@ -111,14 +113,8 @@ def test_cifar100_to_mindrecord_filename_contain_space():
cifar100_transformer.transform()
assert os.path.exists(filename)
assert os.path.exists(filename + "_test")
os.remove("{}".format(filename))
os.remove("{}.db".format(filename))
os.remove("{}".format(filename + "_test"))
os.remove("{}.db".format(filename + "_test"))
def test_cifar100_to_mindrecord_directory():
def test_cifar100_to_mindrecord_directory(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when destination path is directory.
......@@ -129,8 +125,7 @@ def test_cifar100_to_mindrecord_directory():
CIFAR100_DIR)
cifar100_transformer.transform()
def test_cifar100_to_mindrecord_filename_equals_cifar100():
def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when destination path equals source path.
......
......@@ -24,36 +24,60 @@ from mindspore.mindrecord import MRMOpenError, SUCCESS
CIFAR10_DIR = "../data/mindrecord/testCifar10Data"
MINDRECORD_FILE = "./cifar10.mindrecord"
def test_cifar10_to_mindrecord_without_index_fields():
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
remove_file(MINDRECORD_FILE)
yield "yield_fixture_data"
remove_file(MINDRECORD_FILE)
@pytest.fixture
def fixture_space_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
x = "./yes ok"
remove_file(x)
yield "yield_fixture_data"
remove_file(x)
def test_cifar10_to_mindrecord_without_index_fields(fixture_file):
"""test transform cifar10 dataset to mindrecord without index fields."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
cifar10_transformer.transform()
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def test_cifar10_to_mindrecord():
def test_cifar10_to_mindrecord(fixture_file):
"""test transform cifar10 dataset to mindrecord."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
cifar10_transformer.transform(['label'])
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def test_cifar10_to_mindrecord_with_return():
def test_cifar10_to_mindrecord_with_return(fixture_file):
"""test transform cifar10 dataset to mindrecord."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
ret = cifar10_transformer.transform(['label'])
......@@ -61,11 +85,6 @@ def test_cifar10_to_mindrecord_with_return():
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def read():
......@@ -90,8 +109,7 @@ def read():
assert count == 4
reader.close()
def test_cifar10_to_mindrecord_illegal_file_name():
def test_cifar10_to_mindrecord_illegal_file_name(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when file name contains illegal character.
......@@ -101,8 +119,7 @@ def test_cifar10_to_mindrecord_illegal_file_name():
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
cifar10_transformer.transform()
def test_cifar10_to_mindrecord_filename_start_with_space():
def test_cifar10_to_mindrecord_filename_start_with_space(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when file name starts with space.
......@@ -113,8 +130,7 @@ def test_cifar10_to_mindrecord_filename_start_with_space():
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
cifar10_transformer.transform()
def test_cifar10_to_mindrecord_filename_contain_space():
def test_cifar10_to_mindrecord_filename_contain_space(fixture_space_file):
"""
test transform cifar10 dataset to mindrecord
when file name contains space.
......@@ -124,14 +140,8 @@ def test_cifar10_to_mindrecord_filename_contain_space():
cifar10_transformer.transform()
assert os.path.exists(filename)
assert os.path.exists(filename + "_test")
os.remove("{}".format(filename))
os.remove("{}.db".format(filename))
os.remove("{}".format(filename + "_test"))
os.remove("{}.db".format(filename + "_test"))
def test_cifar10_to_mindrecord_directory():
def test_cifar10_to_mindrecord_directory(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when destination path is directory.
......
......@@ -25,6 +25,26 @@ IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images"
MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord"
PARTITION_NUMBER = 4
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_one_file(x):
if os.path.exists(x):
os.remove(x)
def remove_file():
x = MINDRECORD_FILE
remove_one_file(x)
x = MINDRECORD_FILE + ".db"
remove_one_file(x)
for i in range(PARTITION_NUMBER):
x = MINDRECORD_FILE + str(i)
remove_one_file(x)
x = MINDRECORD_FILE + str(i) + ".db"
remove_one_file(x)
remove_file()
yield "yield_fixture_data"
remove_file()
def read(filename):
"""test file reade"""
......@@ -38,8 +58,7 @@ def read(filename):
assert count == 20
reader.close()
def test_imagenet_to_mindrecord():
def test_imagenet_to_mindrecord(fixture_file):
"""test transform imagenet dataset to mindrecord."""
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR,
MINDRECORD_FILE, PARTITION_NUMBER)
......@@ -48,12 +67,8 @@ def test_imagenet_to_mindrecord():
assert os.path.exists(MINDRECORD_FILE + str(i))
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
read(MINDRECORD_FILE + "0")
for i in range(PARTITION_NUMBER):
os.remove(MINDRECORD_FILE + str(i))
os.remove(MINDRECORD_FILE + str(i) + ".db")
def test_imagenet_to_mindrecord_default_partition_number():
def test_imagenet_to_mindrecord_default_partition_number(fixture_file):
"""
test transform imagenet dataset to mindrecord
when partition number is default.
......@@ -64,11 +79,8 @@ def test_imagenet_to_mindrecord_default_partition_number():
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + ".db")
read(MINDRECORD_FILE)
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
def test_imagenet_to_mindrecord_partition_number_0():
def test_imagenet_to_mindrecord_partition_number_0(fixture_file):
"""
test transform imagenet dataset to mindrecord
when partition number is 0.
......@@ -79,8 +91,7 @@ def test_imagenet_to_mindrecord_partition_number_0():
MINDRECORD_FILE, 0)
imagenet_transformer.transform()
def test_imagenet_to_mindrecord_partition_number_none():
def test_imagenet_to_mindrecord_partition_number_none(fixture_file):
"""
test transform imagenet dataset to mindrecord
when partition number is none.
......@@ -92,8 +103,7 @@ def test_imagenet_to_mindrecord_partition_number_none():
MINDRECORD_FILE, None)
imagenet_transformer.transform()
def test_imagenet_to_mindrecord_illegal_filename():
def test_imagenet_to_mindrecord_illegal_filename(fixture_file):
"""
test transform imagenet dataset to mindrecord
when file name contains illegal character.
......
......@@ -26,6 +26,34 @@ CV_FILE_NAME = "./imagenet.mindrecord"
NLP_FILE_NAME = "./aclImdb.mindrecord"
FILES_NUM = 4
def remove_one_file(x):
if os.path.exists(x):
os.remove(x)
def remove_file(file_name):
x = file_name
remove_one_file(x)
x = file_name + ".db"
remove_one_file(x)
for i in range(FILES_NUM):
x = file_name + str(i)
remove_one_file(x)
x = file_name + str(i) + ".db"
remove_one_file(x)
@pytest.fixture
def fixture_cv_file():
"""add/remove file"""
remove_file(CV_FILE_NAME)
yield "yield_fixture_data"
remove_file(CV_FILE_NAME)
@pytest.fixture
def fixture_nlp_file():
"""add/remove file"""
remove_file(NLP_FILE_NAME)
yield "yield_fixture_data"
remove_file(NLP_FILE_NAME)
def test_cv_file_writer_shard_num_none():
"""test cv file writer when shard num is None."""
......@@ -83,8 +111,7 @@ def test_lack_partition_and_db():
'error_msg: MindRecord File could not open successfully.' \
in str(err.value)
def test_lack_db():
def test_lack_db(fixture_cv_file):
"""test file reader when db file does not exist."""
create_cv_mindrecord(1)
os.remove("{}.db".format(CV_FILE_NAME))
......@@ -94,10 +121,8 @@ def test_lack_db():
assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \
in str(err.value)
os.remove(CV_FILE_NAME)
def test_lack_some_partition_and_db():
def test_lack_some_partition_and_db(fixture_cv_file):
"""test file reader when some partition and db do not exist."""
create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
......@@ -110,16 +135,8 @@ def test_lack_some_partition_and_db():
assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \
in str(err.value)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_lack_some_partition_first():
def test_lack_some_partition_first(fixture_cv_file):
"""test file reader when first partition does not exist."""
create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
......@@ -131,14 +148,8 @@ def test_lack_some_partition_first():
assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \
in str(err.value)
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_lack_some_partition_middle():
def test_lack_some_partition_middle(fixture_cv_file):
"""test file reader when some partition does not exist."""
create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
......@@ -150,14 +161,8 @@ def test_lack_some_partition_middle():
assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \
in str(err.value)
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_lack_some_partition_last():
def test_lack_some_partition_last(fixture_cv_file):
"""test file reader when last partition does not exist."""
create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
......@@ -169,14 +174,8 @@ def test_lack_some_partition_last():
assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \
in str(err.value)
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_mindpage_lack_some_partition():
def test_mindpage_lack_some_partition(fixture_cv_file):
"""test page reader when some partition does not exist."""
create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
......@@ -187,14 +186,8 @@ def test_mindpage_lack_some_partition():
assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \
in str(err.value)
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_lack_some_db():
def test_lack_some_db(fixture_cv_file):
"""test file reader when some db does not exist."""
create_cv_mindrecord(4)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
......@@ -206,11 +199,6 @@ def test_lack_some_db():
assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \
in str(err.value)
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
def test_invalid_mindrecord():
......@@ -225,8 +213,7 @@ def test_invalid_mindrecord():
in str(err.value)
os.remove(CV_FILE_NAME)
def test_invalid_db():
def test_invalid_db(fixture_cv_file):
"""test file reader when the content of db is illegal."""
create_cv_mindrecord(1)
os.remove("imagenet.mindrecord.db")
......@@ -237,11 +224,8 @@ def test_invalid_db():
assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \
in str(err.value)
os.remove("imagenet.mindrecord")
os.remove("imagenet.mindrecord.db")
def test_overwrite_invalid_mindrecord():
def test_overwrite_invalid_mindrecord(fixture_cv_file):
"""test file writer when overwrite invalid mindreocrd file."""
with open(CV_FILE_NAME, 'w') as f:
f.write('just for test')
......@@ -250,10 +234,8 @@ def test_overwrite_invalid_mindrecord():
assert '[MRMOpenError]: error_code: 1347690596, ' \
'error_msg: MindRecord File could not open successfully.' \
in str(err.value)
os.remove(CV_FILE_NAME)
def test_overwrite_invalid_db():
def test_overwrite_invalid_db(fixture_cv_file):
"""test file writer when overwrite invalid db file."""
with open('imagenet.mindrecord.db', 'w') as f:
f.write('just for test')
......@@ -261,11 +243,8 @@ def test_overwrite_invalid_db():
create_cv_mindrecord(1)
assert '[MRMGenerateIndexError]: error_code: 1347690612, ' \
'error_msg: Failed to generate index.' in str(err.value)
os.remove("imagenet.mindrecord")
os.remove("imagenet.mindrecord.db")
def test_read_after_close():
def test_read_after_close(fixture_cv_file):
"""test file reader when close read."""
create_cv_mindrecord(1)
reader = FileReader(CV_FILE_NAME)
......@@ -275,11 +254,8 @@ def test_read_after_close():
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 0
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
def test_file_read_after_read():
def test_file_read_after_read(fixture_cv_file):
"""test file reader when finish read."""
create_cv_mindrecord(1)
reader = FileReader(CV_FILE_NAME)
......@@ -295,8 +271,6 @@ def test_file_read_after_read():
cnt = cnt + 1
logger.info("#item{}: {}".format(index, x))
assert cnt == 0
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
def test_cv_file_writer_shard_num_greater_than_1000():
......@@ -312,8 +286,7 @@ def test_add_index_without_add_schema():
fw.add_index(["label"])
assert 'Failed to get meta info' in str(err.value)
def test_mindpage_pageno_pagesize_not_int():
def test_mindpage_pageno_pagesize_not_int(fixture_cv_file):
"""test page reader when some partition does not exist."""
create_cv_mindrecord(4)
reader = MindPage(CV_FILE_NAME + "0")
......@@ -342,14 +315,8 @@ def test_mindpage_pageno_pagesize_not_int():
with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
reader.read_at_page_by_id(99999, 0, 1)
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
def test_mindpage_filename_not_exist():
def test_mindpage_filename_not_exist(fixture_cv_file):
"""test page reader when some partition does not exist."""
create_cv_mindrecord(4)
reader = MindPage(CV_FILE_NAME + "0")
......@@ -374,6 +341,3 @@ def test_mindpage_filename_not_exist():
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
......@@ -14,6 +14,7 @@
"""test mnist to mindrecord tool"""
import cv2
import gzip
import pytest
import numpy as np
import os
......@@ -27,6 +28,34 @@ PARTITION_NUM = 4
IMAGE_SIZE = 28
NUM_CHANNELS = 1
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_one_file(x):
if os.path.exists(x):
os.remove(x)
def remove_file():
x = "mnist_train.mindrecord"
remove_one_file(x)
x = "mnist_train.mindrecord.db"
remove_one_file(x)
x = "mnist_test.mindrecord"
remove_one_file(x)
x = "mnist_test.mindrecord.db"
remove_one_file(x)
for i in range(PARTITION_NUM):
x = "mnist_train.mindrecord" + str(i)
remove_one_file(x)
x = "mnist_train.mindrecord" + str(i) + ".db"
remove_one_file(x)
x = "mnist_test.mindrecord" + str(i)
remove_one_file(x)
x = "mnist_test.mindrecord" + str(i) + ".db"
remove_one_file(x)
remove_file()
yield "yield_fixture_data"
remove_file()
def read(train_name, test_name):
"""test file reader"""
......@@ -51,7 +80,7 @@ def read(train_name, test_name):
reader.close()
def test_mnist_to_mindrecord():
def test_mnist_to_mindrecord(fixture_file):
"""test transform mnist dataset to mindrecord."""
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
mnist_transformer.transform()
......@@ -60,13 +89,7 @@ def test_mnist_to_mindrecord():
read("mnist_train.mindrecord", "mnist_test.mindrecord")
os.remove("{}".format("mnist_train.mindrecord"))
os.remove("{}.db".format("mnist_train.mindrecord"))
os.remove("{}".format("mnist_test.mindrecord"))
os.remove("{}.db".format("mnist_test.mindrecord"))
def test_mnist_to_mindrecord_compare_data():
def test_mnist_to_mindrecord_compare_data(fixture_file):
"""test transform mnist dataset to mindrecord and compare data."""
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
mnist_transformer.transform()
......@@ -121,21 +144,10 @@ def test_mnist_to_mindrecord_compare_data():
assert np.array(x['label']) == label
reader.close()
os.remove("{}".format("mnist_train.mindrecord"))
os.remove("{}.db".format("mnist_train.mindrecord"))
os.remove("{}".format("mnist_test.mindrecord"))
os.remove("{}.db".format("mnist_test.mindrecord"))
def test_mnist_to_mindrecord_multi_partition():
def test_mnist_to_mindrecord_multi_partition(fixture_file):
"""test transform mnist dataset to multiple mindrecord files."""
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
mnist_transformer.transform()
read("mnist_train.mindrecord0", "mnist_test.mindrecord0")
for i in range(PARTITION_NUM):
os.remove("{}".format("mnist_train.mindrecord" + str(i)))
os.remove("{}.db".format("mnist_train.mindrecord" + str(i)))
os.remove("{}".format("mnist_test.mindrecord" + str(i)))
os.remove("{}.db".format("mnist_test.mindrecord" + str(i)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册