提交 235c6997 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!585 回退 'Pull Request !182 : Tuning mindrecord writer performance'

Merge pull request !585 from panfengfeng/revert-merge-182-master
# MindRecord generating guidelines
<!-- TOC -->
- [MindRecord generating guidelines](#mindrecord-generating-guidelines)
- [Create work space](#create-work-space)
- [Implement data generator](#implement-data-generator)
- [Run data generator](#run-data-generator)
<!-- /TOC -->
## Create work space
Assume the dataset name is 'xyz'
* Create work space from template
```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
cp -r template xyz
```
## Implement data generator
Edit dictionary data generator
* Edit file
```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
vi xyz/mr_api.py
```
Two API, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented
- 'mindrecord_task_number()' returns number of tasks. Return 1 if data row is generated serially. Return N if generator can be split into N parallel-run tasks.
- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' is 0..N-1, if N is return value of mindrecord_task_number()
Tricky for parallel run
- For imagenet, one directory can be a task.
- For TFRecord with multiple files, each file can be a task.
- For TFRecord with 1 file only, it could also be split into N tasks. Task_id=K means: data row is picked only if (count % N == K)
## Run data generator
* run python script
```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
python writer.py --mindrecord_script imagenet [...]
```
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord writer.
Two API must be implemented,
1. mindrecord_task_number()
# Return number of parallel tasks. return 1 if no parallel
2. mindrecord_dict_data(task_id)
# Yield data for one task
# task_id is 0..N-1, if N is return value of mindrecord_task_number()
"""
import argparse
import os
import pickle
######## mindrecord_schema begin ##########
mindrecord_schema = {"label": {"type": "int64"},
"data": {"type": "bytes"},
"file_name": {"type": "string"}}
######## mindrecord_schema end ##########
######## Frozen code begin ##########
with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle:
ARG_LIST = pickle.load(mindrecord_argument_file_handle)
######## Frozen code end ##########
parser = argparse.ArgumentParser(description='Mind record imagenet example')
parser.add_argument('--label_file', type=str, default="", help='label file')
parser.add_argument('--image_dir', type=str, default="", help='images directory')
######## Frozen code begin ##########
args = parser.parse_args(ARG_LIST)
print(args)
######## Frozen code end ##########
def _user_defined_private_func():
"""
Internal function for tasks list
Return:
tasks list
"""
if not os.path.exists(args.label_file):
raise IOError("map file {} not exists".format(args.label_file))
label_dict = {}
with open(args.label_file) as file_handle:
line = file_handle.readline()
while line:
labels = line.split(" ")
label_dict[labels[1]] = labels[0]
line = file_handle.readline()
# get all the dir which are n02087046, n02094114, n02109525
dir_paths = {}
for item in label_dict:
real_path = os.path.join(args.image_dir, label_dict[item])
if not os.path.isdir(real_path):
print("{} dir is not exist".format(real_path))
continue
dir_paths[item] = real_path
if not dir_paths:
print("not valid image dir in {}".format(args.image_dir))
return {}, {}
dir_list = []
for label in dir_paths:
dir_list.append(label)
return dir_list, dir_paths
dir_list_global, dir_paths_global = _user_defined_private_func()
def mindrecord_task_number():
"""
Get task size.
Return:
number of tasks
"""
return len(dir_list_global)
def mindrecord_dict_data(task_id):
"""
Get data dict.
Yields:
data (dict): data row which is dict.
"""
# get the filename, label and image binary as a dict
label = dir_list_global[task_id]
for item in os.listdir(dir_paths_global[label]):
file_name = os.path.join(dir_paths_global[label], item)
if not item.endswith("JPEG") and not item.endswith(
"jpg") and not item.endswith("jpeg"):
print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name))
continue
data = {}
data["file_name"] = str(file_name)
data["label"] = int(label)
# get the image data
image_file = open(file_name, "rb")
image_bytes = image_file.read()
image_file.close()
data["data"] = image_bytes
yield data
#!/bin/bash
rm /tmp/imagenet/mr/*
python writer.py --mindrecord_script imagenet \
--mindrecord_file "/tmp/imagenet/mr/m" \
--mindrecord_partitions 16 \
--label_file "/tmp/imagenet/label.txt" \
--image_dir "/tmp/imagenet/jpeg"
#!/bin/bash
rm /tmp/template/*
python writer.py --mindrecord_script template \
--mindrecord_file "/tmp/template/m" \
--mindrecord_partitions 4
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord writer.
Two API must be implemented,
1. mindrecord_task_number()
# Return number of parallel tasks. return 1 if no parallel
2. mindrecord_dict_data(task_id)
# Yield data for one task
# task_id is 0..N-1, if N is return value of mindrecord_task_number()
"""
import argparse
import pickle
# ## Parse argument
with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line
ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line
parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line
# ## Your arguments below
# parser.add_argument(...)
args = parser.parse_args(ARG_LIST) # Do NOT change this line
print(args) # Do NOT change this line
# ## Default mindrecord vars. Comment them unless default value has to be changed.
# mindrecord_index_fields = ['label']
# mindrecord_header_size = 1 << 24
# mindrecord_page_size = 1 << 25
# define global vars here if necessary
# ####### Your code below ##########
mindrecord_schema = {"label": {"type": "int32"}}
def mindrecord_task_number():
"""
Get task size.
Return:
number of tasks
"""
return 1
def mindrecord_dict_data(task_id):
"""
Get data dict.
Yields:
data (dict): data row which is dict.
"""
print("task is {}".format(task_id))
for i in range(256):
data = {}
data['label'] = i
yield data
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
######################## write mindrecord example ########################
Write mindrecord by data dictionary:
python writer.py --mindrecord_script /YourScriptPath ...
"""
import argparse
import os
import pickle
import time
from importlib import import_module
from multiprocessing import Pool
from mindspore.mindrecord import FileWriter
def _exec_task(task_id, parallel_writer=True):
"""
Execute task with specified task id
"""
print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
imagenet_iter = mindrecord_dict_data(task_id)
batch_size = 2048
transform_count = 0
while True:
data_list = []
try:
for _ in range(batch_size):
data_list.append(imagenet_iter.__next__())
transform_count += 1
writer.write_raw_data(data_list, parallel_writer=parallel_writer)
print("transformed {} record...".format(transform_count))
except StopIteration:
if data_list:
writer.write_raw_data(data_list, parallel_writer=parallel_writer)
print("transformed {} record...".format(transform_count))
break
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Mind record writer')
parser.add_argument('--mindrecord_script', type=str, default="template",
help='path where script is saved')
parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord",
help='written file name prefix')
parser.add_argument('--mindrecord_partitions', type=int, default=1,
help='number of written files')
parser.add_argument('--mindrecord_workers', type=int, default=8,
help='number of parallel workers')
args = parser.parse_known_args()
args, other_args = parser.parse_known_args()
print(args)
print(other_args)
with open('mr_argument.pickle', 'wb') as file_handle:
pickle.dump(other_args, file_handle)
try:
mr_api = import_module(args.mindrecord_script + '.mr_api')
except ModuleNotFoundError:
raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))
num_tasks = mr_api.mindrecord_task_number()
print("Write mindrecord ...")
mindrecord_dict_data = mr_api.mindrecord_dict_data
# get number of files
writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)
start_time = time.time()
# set the header size
try:
header_size = mr_api.mindrecord_header_size
writer.set_header_size(header_size)
except AttributeError:
print("Default header size: {}".format(1 << 24))
# set the page size
try:
page_size = mr_api.mindrecord_page_size
writer.set_page_size(page_size)
except AttributeError:
print("Default page size: {}".format(1 << 25))
# get schema
try:
mindrecord_schema = mr_api.mindrecord_schema
except AttributeError:
raise RuntimeError("mindrecord_schema is not defined in mr_api.py.")
# create the schema
writer.add_schema(mindrecord_schema, "mindrecord_schema")
# add the index
try:
index_fields = mr_api.mindrecord_index_fields
writer.add_index(index_fields)
except AttributeError:
print("Default index fields: all simple fields are indexes.")
writer.open_and_set_header()
task_list = list(range(num_tasks))
# set number of workers
num_workers = args.mindrecord_workers
if num_tasks < 1:
num_tasks = 1
if num_workers > num_tasks:
num_workers = num_tasks
if num_tasks > 1:
with Pool(num_workers) as p:
p.map(_exec_task, task_list)
else:
_exec_task(0, False)
ret = writer.commit()
os.remove("{}".format("mr_argument.pickle"))
end_time = time.time()
print("--------------------------------------------")
print("END. Total time: {}".format(end_time - start_time))
print("--------------------------------------------")
......@@ -75,9 +75,12 @@ void BindShardWriter(py::module *m) {
.def("set_header_size", &ShardWriter::set_header_size)
.def("set_page_size", &ShardWriter::set_page_size)
.def("set_shard_header", &ShardWriter::SetShardHeader)
.def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
vector<vector<uint8_t>> &, bool, bool)) &
ShardWriter::WriteRawData)
.def("write_raw_data",
(MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, vector<vector<uint8_t>> &, bool)) &
ShardWriter::WriteRawData)
.def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
std::map<uint64_t, std::vector<py::handle>> &, bool)) &
ShardWriter::WriteRawData)
.def("commit", &ShardWriter::Commit);
}
......
......@@ -121,10 +121,6 @@ class ShardHeader {
std::vector<std::string> SerializeHeader();
MSRStatus PagesToFile(const std::string dump_file_name);
MSRStatus FileToPages(const std::string dump_file_name);
private:
MSRStatus InitializeHeader(const std::vector<json> &headers);
......
......@@ -18,7 +18,6 @@
#define MINDRECORD_INCLUDE_SHARD_WRITER_H_
#include <libgen.h>
#include <sys/file.h>
#include <unistd.h>
#include <algorithm>
#include <array>
......@@ -88,7 +87,7 @@ class ShardWriter {
/// \param[in] sign validate data or not
/// \return MSRStatus the status of MSRStatus to judge if write successfully
MSRStatus WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign = true, bool parallel_writer = false);
bool sign = true);
/// \brief write raw data by group size for call from python
/// \param[in] raw_data the vector of raw json data, python-handle format
......@@ -96,7 +95,7 @@ class ShardWriter {
/// \param[in] sign validate data or not
/// \return MSRStatus the status of MSRStatus to judge if write successfully
MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign = true, bool parallel_writer = false);
bool sign = true);
/// \brief write raw data by group size for call from python
/// \param[in] raw_data the vector of raw json data, python-handle format
......@@ -104,8 +103,7 @@ class ShardWriter {
/// \param[in] sign validate data or not
/// \return MSRStatus the status of MSRStatus to judge if write successfully
MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
bool parallel_writer = false);
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true);
private:
/// \brief write shard header data to disk
......@@ -203,34 +201,7 @@ class ShardWriter {
MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
std::map<int, std::string> &err_raw_data);
/// \brief Lock writer and save pages info
int LockWriter(bool parallel_writer = false);
/// \brief Unlock writer and save pages info
MSRStatus UnlockWriter(int fd, bool parallel_writer = false);
/// \brief Check raw data before writing
MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign, int *schema_count, int *row_count);
/// \brief Get full path from file name
MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths);
/// \brief Open files
MSRStatus OpenDataFiles(bool append);
/// \brief Remove lock file
MSRStatus RemoveLockFile();
/// \brief Remove lock file
MSRStatus InitLockFile();
private:
const std::string kLockFileSuffix = "_Locker";
const std::string kPageFileSuffix = "_Pages";
std::string lock_file_; // lock file for parallel run
std::string pages_file_; // temporary file of pages info for parallel run
int shard_count_; // number of files
uint64_t header_size_; // header size
uint64_t page_size_; // page size
......@@ -240,7 +211,7 @@ class ShardWriter {
std::vector<uint64_t> raw_data_size_; // Raw data size
std::vector<uint64_t> blob_data_size_; // Blob data size
std::vector<std::string> file_paths_; // file paths
std::vector<string> file_paths_; // file paths
std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles
std::shared_ptr<ShardHeader> shard_header_; // shard headers
......
......@@ -520,16 +520,13 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std
for (int raw_page_id : raw_page_ids) {
auto sql = GenerateRawSQL(fields_);
if (sql.first != SUCCESS) {
MS_LOG(ERROR) << "Generate raw SQL failed";
return FAILED;
}
auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in);
if (data.first != SUCCESS) {
MS_LOG(ERROR) << "Generate raw data failed";
return FAILED;
}
if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
MS_LOG(ERROR) << "Execute SQL failed";
return FAILED;
}
MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
......
......@@ -40,7 +40,17 @@ ShardWriter::~ShardWriter() {
}
}
MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) {
MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
shard_count_ = paths.size();
if (shard_count_ > kMaxShardCount || shard_count_ == 0) {
MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0.";
return FAILED;
}
if (schema_count_ > kMaxSchemaCount) {
MS_LOG(ERROR) << "The schema Count greater than max value.";
return FAILED;
}
// Get full path from file name
for (const auto &path : paths) {
if (!CheckIsValidUtf8(path)) {
......@@ -50,7 +60,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &p
char resolved_path[PATH_MAX] = {0};
char buf[PATH_MAX] = {0};
if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
MS_LOG(ERROR) << "Secure func failed";
MS_LOG(ERROR) << "Securec func failed";
return FAILED;
}
#if defined(_WIN32) || defined(_WIN64)
......@@ -72,10 +82,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &p
#endif
file_paths_.emplace_back(string(resolved_path));
}
return SUCCESS;
}
MSRStatus ShardWriter::OpenDataFiles(bool append) {
// Open files
for (const auto &file : file_paths_) {
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
......@@ -109,67 +116,6 @@ MSRStatus ShardWriter::OpenDataFiles(bool append) {
return SUCCESS;
}
MSRStatus ShardWriter::RemoveLockFile() {
// Remove temporary file
int ret = std::remove(pages_file_.c_str());
if (ret == 0) {
MS_LOG(DEBUG) << "Remove page file.";
}
ret = std::remove(lock_file_.c_str());
if (ret == 0) {
MS_LOG(DEBUG) << "Remove lock file.";
}
return SUCCESS;
}
MSRStatus ShardWriter::InitLockFile() {
if (file_paths_.size() == 0) {
MS_LOG(ERROR) << "File path not initialized.";
return FAILED;
}
lock_file_ = file_paths_[0] + kLockFileSuffix;
pages_file_ = file_paths_[0] + kPageFileSuffix;
if (RemoveLockFile() == FAILED) {
MS_LOG(ERROR) << "Remove file failed.";
return FAILED;
}
return SUCCESS;
}
MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
shard_count_ = paths.size();
if (shard_count_ > kMaxShardCount || shard_count_ == 0) {
MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0.";
return FAILED;
}
if (schema_count_ > kMaxSchemaCount) {
MS_LOG(ERROR) << "The schema Count greater than max value.";
return FAILED;
}
// Get full path from file name
if (GetFullPathFromFileName(paths) == FAILED) {
MS_LOG(ERROR) << "Get full path from file name failed.";
return FAILED;
}
// Open files
if (OpenDataFiles(append) == FAILED) {
MS_LOG(ERROR) << "Open data files failed.";
return FAILED;
}
// Init lock file
if (InitLockFile() == FAILED) {
MS_LOG(ERROR) << "Init lock file failed.";
return FAILED;
}
return SUCCESS;
}
MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if (!IsLegalFile(path)) {
return FAILED;
......@@ -197,28 +143,11 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
}
MSRStatus ShardWriter::Commit() {
// Read pages file
std::ifstream page_file(pages_file_.c_str());
if (page_file.good()) {
page_file.close();
if (shard_header_->FileToPages(pages_file_) == FAILED) {
MS_LOG(ERROR) << "Read pages from file failed";
return FAILED;
}
}
if (WriteShardHeader() == FAILED) {
MS_LOG(ERROR) << "Write metadata failed";
return FAILED;
}
MS_LOG(INFO) << "Write metadata successfully.";
// Remove lock file
if (RemoveLockFile() == FAILED) {
MS_LOG(ERROR) << "Remove lock file failed.";
return FAILED;
}
return SUCCESS;
}
......@@ -526,65 +455,15 @@ void ShardWriter::FillArray(int start, int end, std::map<uint64_t, vector<json>>
}
}
int ShardWriter::LockWriter(bool parallel_writer) {
if (!parallel_writer) {
return 0;
}
const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666);
if (fd >= 0) {
flock(fd, LOCK_EX);
} else {
MS_LOG(ERROR) << "Shard writer failed when locking file";
return -1;
}
// Open files
file_streams_.clear();
for (const auto &file : file_paths_) {
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary);
if (fs->fail()) {
MS_LOG(ERROR) << "File could not opened";
return -1;
}
file_streams_.push_back(fs);
}
if (shard_header_->FileToPages(pages_file_) == FAILED) {
MS_LOG(ERROR) << "Read pages from file failed";
return -1;
}
return fd;
}
MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) {
if (!parallel_writer) {
return SUCCESS;
}
if (shard_header_->PagesToFile(pages_file_) == FAILED) {
MS_LOG(ERROR) << "Write pages to file failed";
return FAILED;
}
for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
file_streams_[i]->close();
}
flock(fd, LOCK_UN);
close(fd);
return SUCCESS;
}
MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &blob_data, bool sign, int *schema_count,
int *row_count) {
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &blob_data, bool sign) {
// check the free disk size
auto st_space = GetDiskSize(file_paths_[0], kFreeSize);
if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) {
MS_LOG(ERROR) << "IO error / there is no free disk to be used";
return FAILED;
}
// Add 4-bytes dummy blob data if no any blob fields
if (blob_data.size() == 0 && raw_data.size() > 0) {
blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
......@@ -600,29 +479,10 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
MS_LOG(ERROR) << "Validate raw data failed";
return FAILED;
}
*schema_count = std::get<1>(v);
*row_count = std::get<2>(v);
return SUCCESS;
}
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
// Lock Writer if loading data parallel
int fd = LockWriter(parallel_writer);
if (fd < 0) {
MS_LOG(ERROR) << "Lock writer failed";
return FAILED;
}
// Get the count of schemas and rows
int schema_count = 0;
int row_count = 0;
// Serialize raw data
if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) {
MS_LOG(ERROR) << "Check raw data failed";
return FAILED;
}
int schema_count = std::get<1>(v);
int row_count = std::get<2>(v);
if (row_count == kInt0) {
MS_LOG(INFO) << "Raw data size is 0.";
......@@ -656,17 +516,11 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_d
}
MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully.";
if (UnlockWriter(fd, parallel_writer) == FAILED) {
MS_LOG(ERROR) << "Unlock writer failed";
return FAILED;
}
return SUCCESS;
}
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign,
bool parallel_writer) {
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign) {
std::map<uint64_t, std::vector<json>> raw_data_json;
std::map<uint64_t, std::vector<json>> blob_data_json;
......@@ -700,11 +554,11 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>>
MS_LOG(ERROR) << "Serialize raw data failed in write raw data";
return FAILED;
}
return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer);
return WriteRawData(raw_data_json, bin_blob_data, sign);
}
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
vector<vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
vector<vector<uint8_t>> &blob_data, bool sign) {
std::map<uint64_t, std::vector<json>> raw_data_json;
(void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
[](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
......@@ -714,7 +568,7 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>>
[](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
return std::make_pair(pair.first, std::move(json_raw_data));
});
return WriteRawData(raw_data_json, blob_data, sign, parallel_writer);
return WriteRawData(raw_data_json, blob_data, sign);
}
MSRStatus ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
......
......@@ -677,43 +677,5 @@ std::pair<std::shared_ptr<Statistics>, MSRStatus> ShardHeader::GetStatisticByID(
}
return std::make_pair(statistics_.at(statistic_id), SUCCESS);
}
MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) {
// write header content to file, dump whatever is in the file before
std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out);
if (page_out_handle.fail()) {
MS_LOG(ERROR) << "Failed in opening page file";
return FAILED;
}
auto pages = SerializePage();
for (const auto &shard_pages : pages) {
page_out_handle << shard_pages << "\n";
}
page_out_handle.close();
return SUCCESS;
}
MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
for (auto &v : pages_) { // clean pages
v.clear();
}
// attempt to open the file contains the page in json
std::ifstream page_in_handle(dump_file_name.c_str());
if (!page_in_handle.good()) {
MS_LOG(INFO) << "No page file exists.";
return SUCCESS;
}
std::string line;
while (std::getline(page_in_handle, line)) {
ParsePage(json::parse(line));
}
page_in_handle.close();
return SUCCESS;
}
} // namespace mindrecord
} // namespace mindspore
......@@ -200,24 +200,13 @@ class FileWriter:
raw_data.pop(i)
logger.warning(v)
def open_and_set_header(self):
"""
Open writer and set header
"""
if not self._writer.is_open:
self._writer.open(self._paths)
if not self._writer.get_shard_header():
self._writer.set_shard_header(self._header)
def write_raw_data(self, raw_data, parallel_writer=False):
def write_raw_data(self, raw_data):
"""
Write raw data and generate sequential pair of MindRecord File and \
validate data based on predefined schema by default.
Args:
raw_data (list[dict]): List of raw data.
parallel_writer (bool, optional): Load data parallel if it equals to True (default=False).
Raises:
ParamTypeError: If index field is invalid.
......@@ -236,7 +225,7 @@ class FileWriter:
if not isinstance(each_raw, dict):
raise ParamTypeError('raw_data item', 'dict')
self._verify_based_on_schema(raw_data)
return self._writer.write_raw_data(raw_data, True, parallel_writer)
return self._writer.write_raw_data(raw_data, True)
def set_header_size(self, header_size):
"""
......
......@@ -135,7 +135,7 @@ class ShardWriter:
def get_shard_header(self):
return self._header
def write_raw_data(self, data, validate=True, parallel_writer=False):
def write_raw_data(self, data, validate=True):
"""
Write raw data of cv dataset.
......@@ -145,7 +145,6 @@ class ShardWriter:
Args:
data (list[dict]): List of raw data.
validate (bool, optional): verify data according schema if it equals to True.
parallel_writer (bool, optional): Load data parallel if it equals to True.
Returns:
MSRStatus, SUCCESS or FAILED.
......@@ -166,7 +165,7 @@ class ShardWriter:
if row_raw:
raw_data.append(row_raw)
raw_data = {0: raw_data} if raw_data else {}
ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer)
ret = self._writer.write_raw_data(raw_data, blob_data, validate)
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to write dataset.")
raise MRMWriteDatasetError
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册