Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
c71279bc
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c71279bc
编写于
12月 13, 2018
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine code style for async_executor.h and async_executor.cc
上级
33ee5cad
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
79 addition
and
47 deletion
+79
-47
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+64
-37
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+15
-10
未找到文件。
paddle/fluid/framework/async_executor.cc
浏览文件 @
c71279bc
...
...
@@ -66,15 +66,20 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
}
void
AsyncExecutor
::
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
)
{
_pslib_ptr
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
new
paddle
::
distributed
::
PSlib
());
_pslib_ptr
->
init_server
(
dist_desc
,
index
);
//TODO done
_pslib_ptr
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
new
paddle
::
distributed
::
PSlib
());
_pslib_ptr
->
init_server
(
dist_desc
,
index
);
InitParamConfig
();
}
void
AsyncExecutor
::
InitWorker
(
const
std
::
string
&
dist_desc
,
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
,
int
index
)
{
_pslib_ptr
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
new
paddle
::
distributed
::
PSlib
());
_pslib_ptr
->
init_worker
(
dist_desc
,
host_sign_list
.
data
(),
node_num
,
index
);
//TODO done
void
AsyncExecutor
::
InitWorker
(
const
std
::
string
&
dist_desc
,
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
,
int
index
)
{
_pslib_ptr
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
new
paddle
::
distributed
::
PSlib
());
_pslib_ptr
->
init_worker
(
dist_desc
,
host_sign_list
.
data
(),
node_num
,
index
);
InitParamConfig
();
}
...
...
@@ -87,43 +92,65 @@ void AsyncExecutor::StopServer() {
_pslib_ptr
->
stop_server
();
}
void
AsyncExecutor
::
GatherServers
(
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
)
{
void
AsyncExecutor
::
GatherServers
(
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
)
{
_pslib_ptr
->
gather_servers
(
host_sign_list
.
data
(),
node_num
);
}
void
AsyncExecutor
::
InitParamConfig
()
{
for
(
int
i
=
0
;
i
<
_pslib_ptr
->
get_param
()
->
server_param
().
downpour_server_param
().
downpour_table_param_size
();
++
i
)
{
if
(
_pslib_ptr
->
get_param
()
->
server_param
().
downpour_server_param
().
downpour_table_param
(
i
).
table_class
().
find
(
"SparseTable"
)
!=
-
1
)
{
_param_config
.
fea_dim
=
_pslib_ptr
->
get_param
()
->
server_param
().
downpour_server_param
().
downpour_table_param
(
i
).
accessor
().
fea_dim
();
//TODO
for
(
int
i
=
0
;
i
<
_pslib_ptr
->
get_param
()
->
server_param
().
\
downpour_server_param
().
\
downpour_table_param_size
();
++
i
)
{
if
(
_pslib_ptr
->
get_param
()
->
server_param
().
\
downpour_server_param
().
downpour_table_param
(
i
).
\
table_class
().
find
(
"SparseTable"
)
!=
-
1
)
{
_param_config
.
fea_dim
=
_pslib_ptr
->
get_param
()
->
server_param
().
\
downpour_server_param
().
\
downpour_table_param
(
i
).
\
accessor
().
fea_dim
();
break
;
}
}
_param_config
.
slot_dim
=
_param_config
.
fea_dim
-
2
;
//TODO
_param_config
.
tmp_push_dense_wait_times
=
(
int32_t
)(
_pslib_ptr
->
get_param
()
->
trainer_param
().
push_dense_per_batch
());
_param_config
.
tmp_push_sparse_wait_times
=
(
int32_t
)(
_pslib_ptr
->
get_param
()
->
trainer_param
().
push_sparse_per_batch
());
for
(
auto
t
=
0u
;
t
<
_pslib_ptr
->
get_param
()
->
trainer_param
().
skip_op_size
();
++
t
)
{
_param_config
.
skip_op
.
push_back
(
_pslib_ptr
->
get_param
()
->
trainer_param
().
skip_op
(
t
));
_param_config
.
slot_dim
=
_param_config
.
fea_dim
-
2
;
_param_config
.
tmp_push_dense_wait_times
=
static_cast
<
int32_t
>
(
_pslib_ptr
->
get_param
()
->
trainer_param
().
push_dense_per_batch
());
_param_config
.
tmp_push_sparse_wait_times
=
static_cast
<
int32_t
>
(
_pslib_ptr
->
get_param
()
->
trainer_param
().
push_sparse_per_batch
());
for
(
auto
t
=
0u
;
t
<
_pslib_ptr
->
get_param
()
->
trainer_param
().
skip_op_size
();
++
t
)
{
_param_config
.
skip_op
.
push_back
(
_pslib_ptr
->
get_param
()
->
trainer_param
().
skip_op
(
t
));
}
//sparse
for
(
auto
t
=
0u
;
t
<
_pslib_ptr
->
get_param
()
->
trainer_param
().
sparse_table_size
();
++
t
)
{
for
(
auto
t
=
0u
;
t
<
_pslib_ptr
->
get_param
()
->
trainer_param
().
sparse_table_size
();
++
t
)
{
auto
&
table
=
_pslib_ptr
->
get_param
()
->
trainer_param
().
sparse_table
(
t
);
std
::
vector
<
std
::
string
>
tmp_sparse_variable_name
;
for
(
int
i
=
0u
;
i
<
table
.
slot_value_size
();
++
i
)
{
tmp_sparse_variable_name
.
push_back
(
table
.
slot_value
(
i
));
_param_config
.
slot_alias_to_table
[
table
.
slot_key
(
i
)]
=
table
.
table_id
();
_param_config
.
slot_alias_to_table
[
table
.
slot_key
(
i
)]
=
table
.
table_id
();
}
std
::
vector
<
std
::
string
>
tmp_sparse_gradient_variable_name
;
for
(
auto
i
=
0u
;
i
<
table
.
slot_gradient_size
();
++
i
)
{
tmp_sparse_gradient_variable_name
.
push_back
(
table
.
slot_gradient
(
i
));
}
_param_config
.
slot_input_vec
[
table
.
table_id
()]
=
std
::
move
(
tmp_sparse_variable_name
);
_param_config
.
gradient_var
[
table
.
table_id
()]
=
std
::
move
(
tmp_sparse_gradient_variable_name
);
_param_config
.
slot_input_vec
[
table
.
table_id
()]
=
std
::
move
(
tmp_sparse_variable_name
);
_param_config
.
gradient_var
[
table
.
table_id
()]
=
std
::
move
(
tmp_sparse_gradient_variable_name
);
_param_config
.
sparse_table_id
.
push_back
(
table
.
table_id
());
}
//dense
for
(
auto
t
=
0u
;
t
<
_pslib_ptr
->
get_param
()
->
trainer_param
().
dense_table_size
();
++
t
)
{
for
(
auto
t
=
0u
;
t
<
_pslib_ptr
->
get_param
()
->
trainer_param
().
dense_table_size
();
++
t
)
{
auto
&
table
=
_pslib_ptr
->
get_param
()
->
trainer_param
().
dense_table
(
t
);
std
::
vector
<
std
::
string
>
tmp_dense_variable_name
;
for
(
int
i
=
0u
;
i
<
table
.
dense_variable_name_size
();
++
i
)
{
...
...
@@ -134,20 +161,18 @@ void AsyncExecutor::InitParamConfig() {
tmp_dense_gradient_variable_name
.
push_back
(
table
.
dense_gradient_variable_name
(
i
));
}
_param_config
.
dense_variable_name
[
table
.
table_id
()]
=
std
::
move
(
tmp_dense_variable_name
);
_param_config
.
dense_gradient_variable_name
[
table
.
table_id
()]
=
std
::
move
(
tmp_dense_gradient_variable_name
);
_param_config
.
dense_variable_name
[
table
.
table_id
()]
=
std
::
move
(
tmp_dense_variable_name
);
_param_config
.
dense_gradient_variable_name
[
table
.
table_id
()]
=
std
::
move
(
tmp_dense_gradient_variable_name
);
_param_config
.
dense_table_id
.
push_back
(
table
.
table_id
());
_param_config
.
dense_table_size
.
push_back
(
table
.
fea_dim
());
//TODO
_param_config
.
dense_table_size
.
push_back
(
table
.
fea_dim
());
}
}
void
AsyncExecutor
::
InitModel
()
{
//TODO only rank = 0 do this
//std::vector<int> all_dense_table_id; //TODO
//all_dense_table_id.push_back(0); //done
for
(
auto
table_id
:
_param_config
.
dense_table_id
)
{
for
(
auto
table_id
:
_param_config
.
dense_table_id
)
{
std
::
vector
<
paddle
::
ps
::
Region
>
regions
;
//std::vector<std::string> variables; //TODO
for
(
auto
&
t
:
_param_config
.
dense_variable_name
[
table_id
])
{
Variable
*
var
=
root_scope_
->
FindVar
(
t
);
CHECK
(
var
!=
nullptr
)
<<
"var["
<<
t
<<
"] not found"
;
...
...
@@ -169,13 +194,15 @@ void AsyncExecutor::InitModel() {
regions
.
emplace_back
(
std
::
move
(
reg
));
}
auto
push_status
=
_pslib_ptr
->
_worker_ptr
->
push_dense_param
(
regions
.
data
(),
regions
.
size
(),
table_id
);
auto
push_status
=
_pslib_ptr
->
_worker_ptr
->
push_dense_param
(
regions
.
data
(),
regions
.
size
(),
table_id
);
push_status
.
wait
();
auto
status
=
push_status
.
get
();
if
(
status
!=
0
)
{
LOG
(
FATAL
)
<<
"push dense param failed, status["
<<
status
<<
"]"
;
exit
(
-
1
);
}
}
}
}
...
...
@@ -185,7 +212,7 @@ void AsyncExecutor::SaveModel(const std::string& path) {
ret
=
_pslib_ptr
->
_worker_ptr
->
save
(
path
,
0
);
ret
.
wait
();
int32_t
feasign_cnt
=
ret
.
get
();
if
(
feasign_cnt
==
-
1
)
{
// TODO should be feasign_cnt < 0, because server bug
if
(
feasign_cnt
==
-
1
)
{
// (colourful-tree) TODO should be feasign_cnt < 0
LOG
(
FATAL
)
<<
"save model failed"
;
exit
(
-
1
);
}
...
...
@@ -195,13 +222,13 @@ void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
if
(
mode
==
"mpi"
)
{
DensePullThreadParam
param
;
param
.
ps_client
=
_pslib_ptr
->
_worker_ptr
;;
param
.
threshold
=
1
;
//GlobalConfig::instance().pull_dense_per_batch; //TODO
param
.
threshold
=
1
;
param
.
training_thread_num
=
actual_thread_num
;
param
.
root_scope
=
root_scope_
;
//param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
param
.
dense_params
=
&
_param_config
.
dense_variable_name
;
_pull_dense_thread
=
std
::
shared_ptr
<
DensePullThread
>
(
new
DensePullThread
(
param
));
_pull_dense_thread
=
std
::
shared_ptr
<
DensePullThread
>
(
new
DensePullThread
(
param
));
_pull_dense_thread
->
start
();
}
}
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
c71279bc
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <time.h>
#include <map>
#include <memory>
#include <mutex> // NOLINT
...
...
@@ -22,8 +23,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include <typeinfo>
#include <vector>
#include <random> //local_random_engine
#include <time.h> //local_random_engine
#include <random> // local_random_engine
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
...
...
@@ -43,9 +43,10 @@ inline std::default_random_engine& local_random_engine() {
struct
engine_wrapper_t
{
std
::
default_random_engine
engine
;
engine_wrapper_t
()
{
static
std
::
atomic
<
unsigned
long
>
x
(
0
);
std
::
seed_seq
sseq
=
{
x
++
,
x
++
,
x
++
,
(
unsigned
long
)(
current_realtime
()
*
1000
)};
engine
.
seed
(
sseq
);
static
std
::
atomic
<
uint64
>
x
(
0
);
std
::
seed_seq
sseq
=
{
x
++
,
x
++
,
x
++
,
static_cast
<
uint64
>
(
current_realtime
()
*
1000
)};
engine
.
seed
(
sseq
);
}
};
thread_local
engine_wrapper_t
r
;
...
...
@@ -61,18 +62,20 @@ class AsyncExecutor {
const
std
::
vector
<
std
::
string
>&
filelist
,
const
int
thread_num
,
const
std
::
vector
<
std
::
string
>&
fetch_names
,
const
std
::
string
&
mode
,
const
std
::
string
&
mode
,
const
bool
debug
=
false
);
//void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
void
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
);
void
InitWorker
(
const
std
::
string
&
dist_desc
,
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
,
int
index
);
//void ConfigWorker() {}
void
InitWorker
(
const
std
::
string
&
dist_desc
,
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
,
int
index
);
uint64_t
StartServer
();
void
StopServer
();
void
GatherServers
(
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
);
void
GatherServers
(
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
);
void
InitModel
();
void
SaveModel
(
const
std
::
string
&
path
);
void
InitParamConfig
();
private:
void
CreateThreads
(
ExecutorThreadWorker
*
worker
,
const
ProgramDesc
&
main_program
,
...
...
@@ -81,6 +84,7 @@ class AsyncExecutor {
Scope
*
root_scope
,
const
int
thread_index
,
const
bool
debug
);
void
PrepareDenseThread
(
const
std
::
string
&
mode
);
public:
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
_pslib_ptr
;
std
::
shared_ptr
<
DensePullThread
>
_pull_dense_thread
;
...
...
@@ -88,6 +92,7 @@ class AsyncExecutor {
platform
::
Place
place_
;
AsyncWorkerParamConfig
_param_config
;
private:
int
actual_thread_num
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录