diff --git a/example/convert_to_mindrecord/README.md b/example/convert_to_mindrecord/README.md deleted file mode 100644 index 8d3b25e311b5868de26934b1d8ad52346b73c69e..0000000000000000000000000000000000000000 --- a/example/convert_to_mindrecord/README.md +++ /dev/null @@ -1,46 +0,0 @@ -# MindRecord generating guidelines - - - -- [MindRecord generating guidelines](#mindrecord-generating-guidelines) - - [Create work space](#create-work-space) - - [Implement data generator](#implement-data-generator) - - [Run data generator](#run-data-generator) - - - -## Create work space - -Assume the dataset name is 'xyz' -* Create work space from template - ```shell - cd ${your_mindspore_home}/example/convert_to_mindrecord - cp -r template xyz - ``` - -## Implement data generator - -Edit dictionary data generator -* Edit file - ```shell - cd ${your_mindspore_home}/example/convert_to_mindrecord - vi xyz/mr_api.py - ``` - - Two API, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented -- 'mindrecord_task_number()' returns number of tasks. Return 1 if data row is generated serially. Return N if generator can be split into N parallel-run tasks. -- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' is 0..N-1, if N is return value of mindrecord_task_number() - - -Tricky for parallel run -- For imagenet, one directory can be a task. -- For TFRecord with multiple files, each file can be a task. -- For TFRecord with 1 file only, it could also be split into N tasks. Task_id=K means: data row is picked only if (count % N == K) - - -## Run data generator -* run python script - ```shell - cd ${your_mindspore_home}/example/convert_to_mindrecord - python writer.py --mindrecord_script imagenet [...] - ``` diff --git a/example/convert_to_mindrecord/imagenet/__init__.py b/example/convert_to_mindrecord/imagenet/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/example/convert_to_mindrecord/imagenet/mr_api.py b/example/convert_to_mindrecord/imagenet/mr_api.py deleted file mode 100644 index e569b489b56107cac61f783acd94b670a015e508..0000000000000000000000000000000000000000 --- a/example/convert_to_mindrecord/imagenet/mr_api.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -User-defined API for MindRecord writer. -Two API must be implemented, - 1. mindrecord_task_number() - # Return number of parallel tasks. return 1 if no parallel - 2. mindrecord_dict_data(task_id) - # Yield data for one task - # task_id is 0..N-1, if N is return value of mindrecord_task_number() -""" -import argparse -import os -import pickle - -######## mindrecord_schema begin ########## -mindrecord_schema = {"label": {"type": "int64"}, - "data": {"type": "bytes"}, - "file_name": {"type": "string"}} -######## mindrecord_schema end ########## - -######## Frozen code begin ########## -with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: - ARG_LIST = pickle.load(mindrecord_argument_file_handle) -######## Frozen code end ########## - -parser = argparse.ArgumentParser(description='Mind record imagenet example') -parser.add_argument('--label_file', type=str, default="", help='label file') -parser.add_argument('--image_dir', type=str, default="", help='images directory') - -######## Frozen code begin ########## -args = parser.parse_args(ARG_LIST) -print(args) -######## Frozen code end ########## - - -def _user_defined_private_func(): - """ - Internal function for tasks list - - Return: - tasks list - """ - if not os.path.exists(args.label_file): - raise IOError("map file {} not exists".format(args.label_file)) - - label_dict = {} - with open(args.label_file) as file_handle: - line = file_handle.readline() - while line: - labels = line.split(" ") - label_dict[labels[1]] = labels[0] - line = file_handle.readline() - # get all the dir which are n02087046, n02094114, n02109525 - dir_paths = {} - for item in label_dict: - real_path = os.path.join(args.image_dir, label_dict[item]) - if not os.path.isdir(real_path): - print("{} dir is not exist".format(real_path)) - continue - dir_paths[item] = real_path - - if not dir_paths: - print("not valid image dir in {}".format(args.image_dir)) - return {}, {} - - dir_list = [] - for label in dir_paths: - dir_list.append(label) - return dir_list, dir_paths - - -dir_list_global, dir_paths_global = _user_defined_private_func() - -def mindrecord_task_number(): - """ - Get task size. - - Return: - number of tasks - """ - return len(dir_list_global) - - -def mindrecord_dict_data(task_id): - """ - Get data dict. - - Yields: - data (dict): data row which is dict. - """ - - # get the filename, label and image binary as a dict - label = dir_list_global[task_id] - for item in os.listdir(dir_paths_global[label]): - file_name = os.path.join(dir_paths_global[label], item) - if not item.endswith("JPEG") and not item.endswith( - "jpg") and not item.endswith("jpeg"): - print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name)) - continue - data = {} - data["file_name"] = str(file_name) - data["label"] = int(label) - - # get the image data - image_file = open(file_name, "rb") - image_bytes = image_file.read() - image_file.close() - data["data"] = image_bytes - yield data diff --git a/example/convert_to_mindrecord/run_imagenet.sh b/example/convert_to_mindrecord/run_imagenet.sh deleted file mode 100644 index 11f5dcff75642a0b47df9f851f5a30ac61307e00..0000000000000000000000000000000000000000 --- a/example/convert_to_mindrecord/run_imagenet.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -rm /tmp/imagenet/mr/* - -python writer.py --mindrecord_script imagenet \ ---mindrecord_file "/tmp/imagenet/mr/m" \ ---mindrecord_partitions 16 \ ---label_file "/tmp/imagenet/label.txt" \ ---image_dir "/tmp/imagenet/jpeg" diff --git a/example/convert_to_mindrecord/run_template.sh b/example/convert_to_mindrecord/run_template.sh deleted file mode 100644 index a4c5142c00ad46936260424c9fd20e6ed327df85..0000000000000000000000000000000000000000 --- a/example/convert_to_mindrecord/run_template.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -rm /tmp/template/* - -python writer.py --mindrecord_script template \ ---mindrecord_file "/tmp/template/m" \ ---mindrecord_partitions 4 diff --git a/example/convert_to_mindrecord/template/__init__.py b/example/convert_to_mindrecord/template/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/example/convert_to_mindrecord/template/mr_api.py b/example/convert_to_mindrecord/template/mr_api.py deleted file mode 100644 index 3f7d7dddf0d26ff4869b8a85f4039fdee109f3e8..0000000000000000000000000000000000000000 --- a/example/convert_to_mindrecord/template/mr_api.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -User-defined API for MindRecord writer. -Two API must be implemented, - 1. mindrecord_task_number() - # Return number of parallel tasks. return 1 if no parallel - 2. mindrecord_dict_data(task_id) - # Yield data for one task - # task_id is 0..N-1, if N is return value of mindrecord_task_number() -""" -import argparse -import pickle - -# ## Parse argument - -with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line - ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line -parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line - -# ## Your arguments below -# parser.add_argument(...) - -args = parser.parse_args(ARG_LIST) # Do NOT change this line -print(args) # Do NOT change this line - - -# ## Default mindrecord vars. Comment them unless default value has to be changed. -# mindrecord_index_fields = ['label'] -# mindrecord_header_size = 1 << 24 -# mindrecord_page_size = 1 << 25 - - -# define global vars here if necessary - - -# ####### Your code below ########## -mindrecord_schema = {"label": {"type": "int32"}} - -def mindrecord_task_number(): - """ - Get task size. - - Return: - number of tasks - """ - return 1 - - -def mindrecord_dict_data(task_id): - """ - Get data dict. - - Yields: - data (dict): data row which is dict. - """ - print("task is {}".format(task_id)) - for i in range(256): - data = {} - data['label'] = i - yield data diff --git a/example/convert_to_mindrecord/writer.py b/example/convert_to_mindrecord/writer.py deleted file mode 100644 index 0a9ad5c86aabd6e3ad4eca773e68f312c2a89007..0000000000000000000000000000000000000000 --- a/example/convert_to_mindrecord/writer.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -######################## write mindrecord example ######################## -Write mindrecord by data dictionary: -python writer.py --mindrecord_script /YourScriptPath ... -""" -import argparse -import os -import pickle -import time -from importlib import import_module -from multiprocessing import Pool - -from mindspore.mindrecord import FileWriter - - -def _exec_task(task_id, parallel_writer=True): - """ - Execute task with specified task id - """ - print("exec task {}, parallel: {} ...".format(task_id, parallel_writer)) - imagenet_iter = mindrecord_dict_data(task_id) - batch_size = 2048 - transform_count = 0 - while True: - data_list = [] - try: - for _ in range(batch_size): - data_list.append(imagenet_iter.__next__()) - transform_count += 1 - writer.write_raw_data(data_list, parallel_writer=parallel_writer) - print("transformed {} record...".format(transform_count)) - except StopIteration: - if data_list: - writer.write_raw_data(data_list, parallel_writer=parallel_writer) - print("transformed {} record...".format(transform_count)) - break - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Mind record writer') - parser.add_argument('--mindrecord_script', type=str, default="template", - help='path where script is saved') - - parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord", - help='written file name prefix') - - parser.add_argument('--mindrecord_partitions', type=int, default=1, - help='number of written files') - - parser.add_argument('--mindrecord_workers', type=int, default=8, - help='number of parallel workers') - - args = parser.parse_known_args() - - args, other_args = parser.parse_known_args() - - print(args) - print(other_args) - - with open('mr_argument.pickle', 'wb') as file_handle: - pickle.dump(other_args, file_handle) - - try: - mr_api = import_module(args.mindrecord_script + '.mr_api') - except ModuleNotFoundError: - raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api')) - - num_tasks = mr_api.mindrecord_task_number() - - print("Write mindrecord ...") - - mindrecord_dict_data = mr_api.mindrecord_dict_data - - # get number of files - writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions) - - start_time = time.time() - - # set the header size - try: - header_size = mr_api.mindrecord_header_size - writer.set_header_size(header_size) - except AttributeError: - print("Default header size: {}".format(1 << 24)) - - # set the page size - try: - page_size = mr_api.mindrecord_page_size - writer.set_page_size(page_size) - except AttributeError: - print("Default page size: {}".format(1 << 25)) - - # get schema - try: - mindrecord_schema = mr_api.mindrecord_schema - except AttributeError: - raise RuntimeError("mindrecord_schema is not defined in mr_api.py.") - - # create the schema - writer.add_schema(mindrecord_schema, "mindrecord_schema") - - # add the index - try: - index_fields = mr_api.mindrecord_index_fields - writer.add_index(index_fields) - except AttributeError: - print("Default index fields: all simple fields are indexes.") - - writer.open_and_set_header() - - task_list = list(range(num_tasks)) - - # set number of workers - num_workers = args.mindrecord_workers - - if num_tasks < 1: - num_tasks = 1 - - if num_workers > num_tasks: - num_workers = num_tasks - - if num_tasks > 1: - with Pool(num_workers) as p: - p.map(_exec_task, task_list) - else: - _exec_task(0, False) - - ret = writer.commit() - - os.remove("{}".format("mr_argument.pickle")) - - end_time = time.time() - print("--------------------------------------------") - print("END. Total time: {}".format(end_time - start_time)) - print("--------------------------------------------") diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc index 8718e9b871c9183fc9bccaad33392709fc81e6bd..338a17ac2decfabb5f67404079bd22816b2a0705 100644 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc @@ -75,9 +75,12 @@ void BindShardWriter(py::module *m) { .def("set_header_size", &ShardWriter::set_header_size) .def("set_page_size", &ShardWriter::set_page_size) .def("set_shard_header", &ShardWriter::SetShardHeader) - .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, - vector> &, bool, bool)) & - ShardWriter::WriteRawData) + .def("write_raw_data", + (MSRStatus(ShardWriter::*)(std::map> &, vector> &, bool)) & + ShardWriter::WriteRawData) + .def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map> &, + std::map> &, bool)) & + ShardWriter::WriteRawData) .def("commit", &ShardWriter::Commit); } diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h index 70cfcdb6b7d6091405c125582d6d8b2701dfcf90..ca4d3bd66fa55deb91939eaa8a9ede907a84850f 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/mindrecord/include/shard_header.h @@ -121,10 +121,6 @@ class ShardHeader { std::vector SerializeHeader(); - MSRStatus PagesToFile(const std::string dump_file_name); - - MSRStatus FileToPages(const std::string dump_file_name); - private: MSRStatus InitializeHeader(const std::vector &headers); diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h index 78a434fc97c237ac39ba40d28e9face664de68f4..6a22f07700b99cb3c9e9375497546b7ee79918b2 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/mindrecord/include/shard_writer.h @@ -18,7 +18,6 @@ #define MINDRECORD_INCLUDE_SHARD_WRITER_H_ #include -#include #include #include #include @@ -88,7 +87,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true, bool parallel_writer = false); + bool sign = true); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format @@ -96,7 +95,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true, bool parallel_writer = false); + bool sign = true); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format @@ -104,8 +103,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign = true, - bool parallel_writer = false); + std::map> &blob_data, bool sign = true); private: /// \brief write shard header data to disk @@ -203,34 +201,7 @@ class ShardWriter { MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, std::map &err_raw_data); - /// \brief Lock writer and save pages info - int LockWriter(bool parallel_writer = false); - - /// \brief Unlock writer and save pages info - MSRStatus UnlockWriter(int fd, bool parallel_writer = false); - - /// \brief Check raw data before writing - MSRStatus WriteRawDataPreCheck(std::map> &raw_data, vector> &blob_data, - bool sign, int *schema_count, int *row_count); - - /// \brief Get full path from file name - MSRStatus GetFullPathFromFileName(const std::vector &paths); - - /// \brief Open files - MSRStatus OpenDataFiles(bool append); - - /// \brief Remove lock file - MSRStatus RemoveLockFile(); - - /// \brief Remove lock file - MSRStatus InitLockFile(); - private: - const std::string kLockFileSuffix = "_Locker"; - const std::string kPageFileSuffix = "_Pages"; - std::string lock_file_; // lock file for parallel run - std::string pages_file_; // temporary file of pages info for parallel run - int shard_count_; // number of files uint64_t header_size_; // header size uint64_t page_size_; // page size @@ -240,7 +211,7 @@ class ShardWriter { std::vector raw_data_size_; // Raw data size std::vector blob_data_size_; // Blob data size - std::vector file_paths_; // file paths + std::vector file_paths_; // file paths std::vector> file_streams_; // file handles std::shared_ptr shard_header_; // shard headers diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc index dc2743cdc7434cb6cd18e6c2dac9d9a0bfdeb49e..5a5cd7cbf3157a16f428b53443e553f7c4528af7 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc @@ -520,16 +520,13 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std for (int raw_page_id : raw_page_ids) { auto sql = GenerateRawSQL(fields_); if (sql.first != SUCCESS) { - MS_LOG(ERROR) << "Generate raw SQL failed"; return FAILED; } auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); if (data.first != SUCCESS) { - MS_LOG(ERROR) << "Generate raw data failed"; return FAILED; } if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { - MS_LOG(ERROR) << "Execute SQL failed"; return FAILED; } MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc index ac95e622c9ae5d2bac2866d7b54b1e389692b29c..864e6697d03eb932298899f44975ed4103ef38e4 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc @@ -40,7 +40,17 @@ ShardWriter::~ShardWriter() { } } -MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &paths) { +MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { + shard_count_ = paths.size(); + if (shard_count_ > kMaxShardCount || shard_count_ == 0) { + MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; + return FAILED; + } + if (schema_count_ > kMaxSchemaCount) { + MS_LOG(ERROR) << "The schema Count greater than max value."; + return FAILED; + } + // Get full path from file name for (const auto &path : paths) { if (!CheckIsValidUtf8(path)) { @@ -50,7 +60,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &p char resolved_path[PATH_MAX] = {0}; char buf[PATH_MAX] = {0}; if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Secure func failed"; + MS_LOG(ERROR) << "Securec func failed"; return FAILED; } #if defined(_WIN32) || defined(_WIN64) @@ -72,10 +82,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &p #endif file_paths_.emplace_back(string(resolved_path)); } - return SUCCESS; -} -MSRStatus ShardWriter::OpenDataFiles(bool append) { // Open files for (const auto &file : file_paths_) { std::shared_ptr fs = std::make_shared(); @@ -109,67 +116,6 @@ MSRStatus ShardWriter::OpenDataFiles(bool append) { return SUCCESS; } -MSRStatus ShardWriter::RemoveLockFile() { - // Remove temporary file - int ret = std::remove(pages_file_.c_str()); - if (ret == 0) { - MS_LOG(DEBUG) << "Remove page file."; - } - - ret = std::remove(lock_file_.c_str()); - if (ret == 0) { - MS_LOG(DEBUG) << "Remove lock file."; - } - return SUCCESS; -} - -MSRStatus ShardWriter::InitLockFile() { - if (file_paths_.size() == 0) { - MS_LOG(ERROR) << "File path not initialized."; - return FAILED; - } - - lock_file_ = file_paths_[0] + kLockFileSuffix; - pages_file_ = file_paths_[0] + kPageFileSuffix; - - if (RemoveLockFile() == FAILED) { - MS_LOG(ERROR) << "Remove file failed."; - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { - shard_count_ = paths.size(); - if (shard_count_ > kMaxShardCount || shard_count_ == 0) { - MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; - return FAILED; - } - if (schema_count_ > kMaxSchemaCount) { - MS_LOG(ERROR) << "The schema Count greater than max value."; - return FAILED; - } - - // Get full path from file name - if (GetFullPathFromFileName(paths) == FAILED) { - MS_LOG(ERROR) << "Get full path from file name failed."; - return FAILED; - } - - // Open files - if (OpenDataFiles(append) == FAILED) { - MS_LOG(ERROR) << "Open data files failed."; - return FAILED; - } - - // Init lock file - if (InitLockFile() == FAILED) { - MS_LOG(ERROR) << "Init lock file failed."; - return FAILED; - } - return SUCCESS; -} - MSRStatus ShardWriter::OpenForAppend(const std::string &path) { if (!IsLegalFile(path)) { return FAILED; @@ -197,28 +143,11 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { } MSRStatus ShardWriter::Commit() { - // Read pages file - std::ifstream page_file(pages_file_.c_str()); - if (page_file.good()) { - page_file.close(); - if (shard_header_->FileToPages(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Read pages from file failed"; - return FAILED; - } - } - if (WriteShardHeader() == FAILED) { MS_LOG(ERROR) << "Write metadata failed"; return FAILED; } MS_LOG(INFO) << "Write metadata successfully."; - - // Remove lock file - if (RemoveLockFile() == FAILED) { - MS_LOG(ERROR) << "Remove lock file failed."; - return FAILED; - } - return SUCCESS; } @@ -526,65 +455,15 @@ void ShardWriter::FillArray(int start, int end, std::map> } } -int ShardWriter::LockWriter(bool parallel_writer) { - if (!parallel_writer) { - return 0; - } - const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); - if (fd >= 0) { - flock(fd, LOCK_EX); - } else { - MS_LOG(ERROR) << "Shard writer failed when locking file"; - return -1; - } - - // Open files - file_streams_.clear(); - for (const auto &file : file_paths_) { - std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); - if (fs->fail()) { - MS_LOG(ERROR) << "File could not opened"; - return -1; - } - file_streams_.push_back(fs); - } - - if (shard_header_->FileToPages(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Read pages from file failed"; - return -1; - } - return fd; -} - -MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { - if (!parallel_writer) { - return SUCCESS; - } - - if (shard_header_->PagesToFile(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Write pages to file failed"; - return FAILED; - } - - for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { - file_streams_[i]->close(); - } - - flock(fd, LOCK_UN); - close(fd); - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, - std::vector> &blob_data, bool sign, int *schema_count, - int *row_count) { +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign) { // check the free disk size auto st_space = GetDiskSize(file_paths_[0], kFreeSize); if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { MS_LOG(ERROR) << "IO error / there is no free disk to be used"; return FAILED; } + // Add 4-bytes dummy blob data if no any blob fields if (blob_data.size() == 0 && raw_data.size() > 0) { blob_data = std::vector>(raw_data[0].size(), std::vector(kUnsignedInt4, 0)); @@ -600,29 +479,10 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map MS_LOG(ERROR) << "Validate raw data failed"; return FAILED; } - *schema_count = std::get<1>(v); - *row_count = std::get<2>(v); - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign, bool parallel_writer) { - // Lock Writer if loading data parallel - int fd = LockWriter(parallel_writer); - if (fd < 0) { - MS_LOG(ERROR) << "Lock writer failed"; - return FAILED; - } // Get the count of schemas and rows - int schema_count = 0; - int row_count = 0; - - // Serialize raw data - if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { - MS_LOG(ERROR) << "Check raw data failed"; - return FAILED; - } + int schema_count = std::get<1>(v); + int row_count = std::get<2>(v); if (row_count == kInt0) { MS_LOG(INFO) << "Raw data size is 0."; @@ -656,17 +516,11 @@ MSRStatus ShardWriter::WriteRawData(std::map> &raw_d } MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully."; - if (UnlockWriter(fd, parallel_writer) == FAILED) { - MS_LOG(ERROR) << "Unlock writer failed"; - return FAILED; - } - return SUCCESS; } MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign, - bool parallel_writer) { + std::map> &blob_data, bool sign) { std::map> raw_data_json; std::map> blob_data_json; @@ -700,11 +554,11 @@ MSRStatus ShardWriter::WriteRawData(std::map> MS_LOG(ERROR) << "Serialize raw data failed in write raw data"; return FAILED; } - return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer); + return WriteRawData(raw_data_json, bin_blob_data, sign); } MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - vector> &blob_data, bool sign, bool parallel_writer) { + vector> &blob_data, bool sign) { std::map> raw_data_json; (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), [](const std::pair> &pair) { @@ -714,7 +568,7 @@ MSRStatus ShardWriter::WriteRawData(std::map> [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); return std::make_pair(pair.first, std::move(json_raw_data)); }); - return WriteRawData(raw_data_json, blob_data, sign, parallel_writer); + return WriteRawData(raw_data_json, blob_data, sign); } MSRStatus ShardWriter::ParallelWriteData(const std::vector> &blob_data, diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc index 26008e3ca96482cb830bd53797a740cc5f0cc0d8..57b2e5fa9eaf50f76854e01a3df60689f67912b9 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_header.cc @@ -677,43 +677,5 @@ std::pair, MSRStatus> ShardHeader::GetStatisticByID( } return std::make_pair(statistics_.at(statistic_id), SUCCESS); } - -MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { - // write header content to file, dump whatever is in the file before - std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out); - if (page_out_handle.fail()) { - MS_LOG(ERROR) << "Failed in opening page file"; - return FAILED; - } - - auto pages = SerializePage(); - for (const auto &shard_pages : pages) { - page_out_handle << shard_pages << "\n"; - } - - page_out_handle.close(); - return SUCCESS; -} - -MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { - for (auto &v : pages_) { // clean pages - v.clear(); - } - // attempt to open the file contains the page in json - std::ifstream page_in_handle(dump_file_name.c_str()); - - if (!page_in_handle.good()) { - MS_LOG(INFO) << "No page file exists."; - return SUCCESS; - } - - std::string line; - while (std::getline(page_in_handle, line)) { - ParsePage(json::parse(line)); - } - - page_in_handle.close(); - return SUCCESS; -} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/mindrecord/filewriter.py b/mindspore/mindrecord/filewriter.py index 62bcc2df79dcb240c1b2b7218b7a07df186a20a0..90bca480382507b1e414e4414e74fbbb2b2274c8 100644 --- a/mindspore/mindrecord/filewriter.py +++ b/mindspore/mindrecord/filewriter.py @@ -200,24 +200,13 @@ class FileWriter: raw_data.pop(i) logger.warning(v) - def open_and_set_header(self): - """ - Open writer and set header - - """ - if not self._writer.is_open: - self._writer.open(self._paths) - if not self._writer.get_shard_header(): - self._writer.set_shard_header(self._header) - - def write_raw_data(self, raw_data, parallel_writer=False): + def write_raw_data(self, raw_data): """ Write raw data and generate sequential pair of MindRecord File and \ validate data based on predefined schema by default. Args: raw_data (list[dict]): List of raw data. - parallel_writer (bool, optional): Load data parallel if it equals to True (default=False). Raises: ParamTypeError: If index field is invalid. @@ -236,7 +225,7 @@ class FileWriter: if not isinstance(each_raw, dict): raise ParamTypeError('raw_data item', 'dict') self._verify_based_on_schema(raw_data) - return self._writer.write_raw_data(raw_data, True, parallel_writer) + return self._writer.write_raw_data(raw_data, True) def set_header_size(self, header_size): """ diff --git a/mindspore/mindrecord/shardwriter.py b/mindspore/mindrecord/shardwriter.py index 0913201861ce8d77c2e44ee8a2e4169faa626b69..0ef23d4ce66ac34d0fe8b1944162291b6f962aa2 100644 --- a/mindspore/mindrecord/shardwriter.py +++ b/mindspore/mindrecord/shardwriter.py @@ -135,7 +135,7 @@ class ShardWriter: def get_shard_header(self): return self._header - def write_raw_data(self, data, validate=True, parallel_writer=False): + def write_raw_data(self, data, validate=True): """ Write raw data of cv dataset. @@ -145,7 +145,6 @@ class ShardWriter: Args: data (list[dict]): List of raw data. validate (bool, optional): verify data according schema if it equals to True. - parallel_writer (bool, optional): Load data parallel if it equals to True. Returns: MSRStatus, SUCCESS or FAILED. @@ -166,7 +165,7 @@ class ShardWriter: if row_raw: raw_data.append(row_raw) raw_data = {0: raw_data} if raw_data else {} - ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer) + ret = self._writer.write_raw_data(raw_data, blob_data, validate) if ret != ms.MSRStatus.SUCCESS: logger.error("Failed to write dataset.") raise MRMWriteDatasetError