Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
bd7ac259
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bd7ac259
编写于
4月 09, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 09, 2020
浏览文件
操作
浏览文件
下载
差异文件
!147 Parallelize mindrecord index writer via std::thread
Merge pull request !147 from ZiruiWu/master
上级
3e369823
5637f806
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
72 addition
and
36 deletion
+72
-36
mindspore/ccsrc/mindrecord/include/shard_index_generator.h
mindspore/ccsrc/mindrecord/include/shard_index_generator.h
+7
-3
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
+65
-33
未找到文件。
mindspore/ccsrc/mindrecord/include/shard_index_generator.h
浏览文件 @
bd7ac259
...
...
@@ -85,14 +85,14 @@ class ShardIndexGenerator {
/// \param sql
/// \param data
/// \return
MSRStatus
BindParam
a
terExecuteSQL
(
MSRStatus
BindParam
e
terExecuteSQL
(
sqlite3
*
db
,
const
std
::
string
&
sql
,
const
std
::
vector
<
std
::
vector
<
std
::
tuple
<
std
::
string
,
std
::
string
,
std
::
string
>>>
&
data
);
INDEX_FIELDS
GenerateIndexFields
(
const
std
::
vector
<
json
>
&
schema_detail
);
MSRStatus
ExcuteTransaction
(
const
int
&
shard_no
,
const
std
::
pair
<
MSRStatus
,
sqlite3
*>
&
db
,
const
std
::
vector
<
int
>
&
raw_page_ids
,
const
std
::
map
<
int
,
int
>
&
blob_id_to_page_id
);
MSRStatus
Ex
e
cuteTransaction
(
const
int
&
shard_no
,
const
std
::
pair
<
MSRStatus
,
sqlite3
*>
&
db
,
const
std
::
vector
<
int
>
&
raw_page_ids
,
const
std
::
map
<
int
,
int
>
&
blob_id_to_page_id
);
MSRStatus
CreateShardNameTable
(
sqlite3
*
db
,
const
std
::
string
&
shard_name
);
...
...
@@ -103,12 +103,16 @@ class ShardIndexGenerator {
void
AddIndexFieldByRawData
(
const
std
::
vector
<
json
>
&
schema_detail
,
std
::
vector
<
std
::
tuple
<
std
::
string
,
std
::
string
,
std
::
string
>>
&
row_data
);
void
DatabaseWriter
();
// worker thread
std
::
string
file_path_
;
bool
append_
;
ShardHeader
shard_header_
;
uint64_t
page_size_
;
uint64_t
header_size_
;
int
schema_count_
;
std
::
atomic_int
task_
;
std
::
atomic_bool
write_success_
;
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
fields_
;
};
}
// namespace mindrecord
...
...
mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
浏览文件 @
bd7ac259
...
...
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thread>
#include "mindrecord/include/shard_index_generator.h"
#include "common/utils.h"
...
...
@@ -26,7 +27,13 @@ using mindspore::MsLogLevel::INFO;
namespace
mindspore
{
namespace
mindrecord
{
ShardIndexGenerator
::
ShardIndexGenerator
(
const
std
::
string
&
file_path
,
bool
append
)
:
file_path_
(
file_path
),
append_
(
append
),
page_size_
(
0
),
header_size_
(
0
),
schema_count_
(
0
)
{}
:
file_path_
(
file_path
),
append_
(
append
),
page_size_
(
0
),
header_size_
(
0
),
schema_count_
(
0
),
task_
(
0
),
write_success_
(
true
)
{}
MSRStatus
ShardIndexGenerator
::
Build
()
{
ShardHeader
header
=
ShardHeader
();
...
...
@@ -284,7 +291,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateRawSQL(
return
{
SUCCESS
,
sql
};
}
MSRStatus
ShardIndexGenerator
::
BindParam
a
terExecuteSQL
(
MSRStatus
ShardIndexGenerator
::
BindParam
e
terExecuteSQL
(
sqlite3
*
db
,
const
std
::
string
&
sql
,
const
std
::
vector
<
std
::
vector
<
std
::
tuple
<
std
::
string
,
std
::
string
,
std
::
string
>>>
&
data
)
{
sqlite3_stmt
*
stmt
=
nullptr
;
...
...
@@ -471,9 +478,9 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
return
{
SUCCESS
,
std
::
move
(
fields
)};
}
MSRStatus
ShardIndexGenerator
::
ExcuteTransaction
(
const
int
&
shard_no
,
const
std
::
pair
<
MSRStatus
,
sqlite3
*>
&
db
,
const
std
::
vector
<
int
>
&
raw_page_ids
,
const
std
::
map
<
int
,
int
>
&
blob_id_to_page_id
)
{
MSRStatus
ShardIndexGenerator
::
Ex
e
cuteTransaction
(
const
int
&
shard_no
,
const
std
::
pair
<
MSRStatus
,
sqlite3
*>
&
db
,
const
std
::
vector
<
int
>
&
raw_page_ids
,
const
std
::
map
<
int
,
int
>
&
blob_id_to_page_id
)
{
// Add index data to database
std
::
string
shard_address
=
shard_header_
.
get_shard_address_by_id
(
shard_no
);
if
(
shard_address
.
empty
())
{
...
...
@@ -493,7 +500,7 @@ MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std:
if
(
data
.
first
!=
SUCCESS
)
{
return
FAILED
;
}
if
(
BindParam
a
terExecuteSQL
(
db
.
second
,
sql
.
second
,
data
.
second
)
==
FAILED
)
{
if
(
BindParam
e
terExecuteSQL
(
db
.
second
,
sql
.
second
,
data
.
second
)
==
FAILED
)
{
return
FAILED
;
}
MS_LOG
(
INFO
)
<<
"Insert "
<<
data
.
second
.
size
()
<<
" rows to index db."
;
...
...
@@ -514,37 +521,62 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
page_size_
=
shard_header_
.
get_page_size
();
header_size_
=
shard_header_
.
get_header_size
();
schema_count_
=
shard_header_
.
get_schema_count
();
if
(
shard_header_
.
get_shard_count
()
<=
kMaxShardCount
)
{
// Create one database per shard
for
(
int
shard_no
=
0
;
shard_no
<
shard_header_
.
get_shard_count
();
++
shard_no
)
{
// Create database
auto
db
=
CreateDatabase
(
shard_no
);
if
(
db
.
first
!=
SUCCESS
||
db
.
second
==
nullptr
)
{
return
FAILED
;
}
MS_LOG
(
INFO
)
<<
"Init index db for shard: "
<<
shard_no
<<
" successfully."
;
// Pre-processing page information
auto
total_pages
=
shard_header_
.
GetLastPageId
(
shard_no
)
+
1
;
std
::
map
<
int
,
int
>
blob_id_to_page_id
;
std
::
vector
<
int
>
raw_page_ids
;
for
(
uint64_t
i
=
0
;
i
<
total_pages
;
++
i
)
{
std
::
shared_ptr
<
Page
>
cur_page
=
shard_header_
.
GetPage
(
shard_no
,
i
).
first
;
if
(
cur_page
->
get_page_type
()
==
"RAW_DATA"
)
{
raw_page_ids
.
push_back
(
i
);
}
else
if
(
cur_page
->
get_page_type
()
==
"BLOB_DATA"
)
{
blob_id_to_page_id
[
cur_page
->
get_page_type_id
()]
=
i
;
}
}
if
(
shard_header_
.
get_shard_count
()
>
kMaxShardCount
)
{
MS_LOG
(
ERROR
)
<<
"num shards: "
<<
shard_header_
.
get_shard_count
()
<<
" exceeds max count:"
<<
kMaxSchemaCount
;
return
FAILED
;
}
task_
=
0
;
// set two atomic vars to initial value
write_success_
=
true
;
if
(
ExcuteTransaction
(
shard_no
,
db
,
raw_page_ids
,
blob_id_to_page_id
)
!=
SUCCESS
)
{
return
FAILED
;
// spawn half the physical threads or total number of shards whichever is smaller
const
unsigned
int
num_workers
=
std
::
min
(
std
::
thread
::
hardware_concurrency
()
/
2
+
1
,
static_cast
<
unsigned
int
>
(
shard_header_
.
get_shard_count
()));
std
::
vector
<
std
::
thread
>
threads
;
threads
.
reserve
(
num_workers
);
for
(
size_t
t
=
0
;
t
<
threads
.
capacity
();
t
++
)
{
threads
.
emplace_back
(
std
::
thread
(
&
ShardIndexGenerator
::
DatabaseWriter
,
this
));
}
for
(
size_t
t
=
0
;
t
<
threads
.
capacity
();
t
++
)
{
threads
[
t
].
join
();
}
return
write_success_
?
SUCCESS
:
FAILED
;
}
void
ShardIndexGenerator
::
DatabaseWriter
()
{
int
shard_no
=
task_
++
;
while
(
shard_no
<
shard_header_
.
get_shard_count
())
{
auto
db
=
CreateDatabase
(
shard_no
);
if
(
db
.
first
!=
SUCCESS
||
db
.
second
==
nullptr
||
write_success_
==
false
)
{
write_success_
=
false
;
return
;
}
MS_LOG
(
INFO
)
<<
"Init index db for shard: "
<<
shard_no
<<
" successfully."
;
// Pre-processing page information
auto
total_pages
=
shard_header_
.
GetLastPageId
(
shard_no
)
+
1
;
std
::
map
<
int
,
int
>
blob_id_to_page_id
;
std
::
vector
<
int
>
raw_page_ids
;
for
(
uint64_t
i
=
0
;
i
<
total_pages
;
++
i
)
{
std
::
shared_ptr
<
Page
>
cur_page
=
shard_header_
.
GetPage
(
shard_no
,
i
).
first
;
if
(
cur_page
->
get_page_type
()
==
"RAW_DATA"
)
{
raw_page_ids
.
push_back
(
i
);
}
else
if
(
cur_page
->
get_page_type
()
==
"BLOB_DATA"
)
{
blob_id_to_page_id
[
cur_page
->
get_page_type_id
()]
=
i
;
}
MS_LOG
(
INFO
)
<<
"Generate index db for shard: "
<<
shard_no
<<
" successfully."
;
}
if
(
ExecuteTransaction
(
shard_no
,
db
,
raw_page_ids
,
blob_id_to_page_id
)
!=
SUCCESS
)
{
write_success_
=
false
;
return
;
}
MS_LOG
(
INFO
)
<<
"Generate index db for shard: "
<<
shard_no
<<
" successfully."
;
shard_no
=
task_
++
;
}
return
SUCCESS
;
}
}
// namespace mindrecord
}
// namespace mindspore
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录