Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
10ed9e0a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
10ed9e0a
编写于
12月 11, 2018
作者:
H
heqiaozhi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
download & run & instance
上级
57ac412b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
39 addition
and
25 deletion
+39
-25
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+23
-15
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+2
-1
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+14
-9
未找到文件。
paddle/fluid/framework/async_executor.cc
浏览文件 @
10ed9e0a
...
@@ -191,18 +191,19 @@ void AsyncExecutor::SaveModel(const std::string& path) {
...
@@ -191,18 +191,19 @@ void AsyncExecutor::SaveModel(const std::string& path) {
}
}
}
}
void
AsyncExecutor
::
PrepareDenseThread
()
{
void
AsyncExecutor
::
PrepareDenseThread
(
const
std
::
string
&
mode
)
{
DensePullThreadParam
param
;
if
(
mode
==
"mpi"
)
{
param
.
ps_client
=
_pslib_ptr
->
_worker_ptr
;;
DensePullThreadParam
param
;
param
.
threshold
=
1
;
//GlobalConfig::instance().pull_dense_per_batch; //TODO
param
.
ps_client
=
_pslib_ptr
->
_worker_ptr
;;
param
.
training_thread_num
=
actual_thread_num
;
param
.
threshold
=
1
;
//GlobalConfig::instance().pull_dense_per_batch; //TODO
param
.
root_scope
=
root_scope_
;
param
.
training_thread_num
=
actual_thread_num
;
//param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
param
.
root_scope
=
root_scope_
;
param
.
dense_params
=
&
_param_config
.
dense_variable_name
;
//param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
param
.
dense_params
=
&
_param_config
.
dense_variable_name
;
_pull_dense_thread
=
std
::
shared_ptr
<
DensePullThread
>
(
new
DensePullThread
(
param
));
_pull_dense_thread
->
start
();
_pull_dense_thread
=
std
::
shared_ptr
<
DensePullThread
>
(
new
DensePullThread
(
param
));
_pull_dense_thread
->
start
();
}
}
}
void
AsyncExecutor
::
RunFromFile
(
const
ProgramDesc
&
main_program
,
void
AsyncExecutor
::
RunFromFile
(
const
ProgramDesc
&
main_program
,
...
@@ -210,6 +211,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
...
@@ -210,6 +211,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const
std
::
vector
<
std
::
string
>&
filelist
,
const
std
::
vector
<
std
::
string
>&
filelist
,
const
int
thread_num
,
const
int
thread_num
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
const
std
::
string
&
mode
,
const
bool
debug
)
{
const
bool
debug
)
{
std
::
vector
<
std
::
thread
>
threads
;
std
::
vector
<
std
::
thread
>
threads
;
...
@@ -251,11 +253,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
...
@@ -251,11 +253,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
// todo: should be factory method for creating datafeed
// todo: should be factory method for creating datafeed
std
::
vector
<
std
::
shared_ptr
<
DataFeed
>>
readers
;
std
::
vector
<
std
::
shared_ptr
<
DataFeed
>>
readers
;
PrepareReaders
(
readers
,
actual_thread_num
,
data_feed_desc
,
filelist
);
PrepareReaders
(
readers
,
actual_thread_num
,
data_feed_desc
,
filelist
);
PrepareDenseThread
();
PrepareDenseThread
(
mode
);
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>>
workers
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>>
workers
;
workers
.
resize
(
actual_thread_num
);
workers
.
resize
(
actual_thread_num
);
for
(
auto
&
worker
:
workers
)
{
for
(
auto
&
worker
:
workers
)
{
worker
.
reset
(
new
AsyncExecutorThreadWorker
);
if
(
mode
==
"mpi"
)
{
worker
.
reset
(
new
AsyncExecutorThreadWorker
);
}
else
{
worker
.
reset
(
new
ExecutorThreadWorker
);
}
}
}
// prepare thread resource here
// prepare thread resource here
...
@@ -274,7 +280,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
...
@@ -274,7 +280,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for
(
auto
&
th
:
threads
)
{
for
(
auto
&
th
:
threads
)
{
th
.
join
();
th
.
join
();
}
}
_pull_dense_thread
->
stop
();
if
(
mode
==
"mpi"
)
{
_pull_dense_thread
->
stop
();
}
root_scope_
->
DropKids
();
root_scope_
->
DropKids
();
return
;
return
;
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
10ed9e0a
...
@@ -61,6 +61,7 @@ class AsyncExecutor {
...
@@ -61,6 +61,7 @@ class AsyncExecutor {
const
std
::
vector
<
std
::
string
>&
filelist
,
const
std
::
vector
<
std
::
string
>&
filelist
,
const
int
thread_num
,
const
int
thread_num
,
const
std
::
vector
<
std
::
string
>&
fetch_names
,
const
std
::
vector
<
std
::
string
>&
fetch_names
,
const
std
::
string
&
mode
,
const
bool
debug
=
false
);
const
bool
debug
=
false
);
//void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
//void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
void
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
);
void
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
);
...
@@ -79,7 +80,7 @@ class AsyncExecutor {
...
@@ -79,7 +80,7 @@ class AsyncExecutor {
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
Scope
*
root_scope
,
const
int
thread_index
,
Scope
*
root_scope
,
const
int
thread_index
,
const
bool
debug
);
const
bool
debug
);
void
PrepareDenseThread
();
void
PrepareDenseThread
(
const
std
::
string
&
mode
);
public:
public:
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
_pslib_ptr
;
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
_pslib_ptr
;
std
::
shared_ptr
<
DensePullThread
>
_pull_dense_thread
;
std
::
shared_ptr
<
DensePullThread
>
_pull_dense_thread
;
...
...
python/paddle/fluid/async_executor.py
浏览文件 @
10ed9e0a
...
@@ -87,9 +87,8 @@ class AsyncExecutor(object):
...
@@ -87,9 +87,8 @@ class AsyncExecutor(object):
scope
=
global_scope
()
scope
=
global_scope
()
self
.
executor
=
core
.
AsyncExecutor
(
scope
,
p
)
self
.
executor
=
core
.
AsyncExecutor
(
scope
,
p
)
self
.
instance
=
ps_instance
.
PaddlePSInstance
(
1
,
2
)
def
run
(
self
,
program
,
data_feed
,
filelist
,
thread_num
,
fetch
,
debug
=
False
):
def
run
(
self
,
program
,
data_feed
,
filelist
,
thread_num
,
fetch
,
mode
=
""
,
debug
=
False
):
"""
"""
Run program by this AsyncExecutor. Training dataset will be in filelist.
Run program by this AsyncExecutor. Training dataset will be in filelist.
Users can also inspect certain variables by naming them in parameter
Users can also inspect certain variables by naming them in parameter
...
@@ -151,10 +150,11 @@ class AsyncExecutor(object):
...
@@ -151,10 +150,11 @@ class AsyncExecutor(object):
self
.
executor
.
run_from_files
(
program_desc
,
self
.
executor
.
run_from_files
(
program_desc
,
data_feed
.
desc
(),
filelist
,
thread_num
,
data_feed
.
desc
(),
filelist
,
thread_num
,
fetch_var_names
,
debug
)
fetch_var_names
,
mode
,
debug
)
def
download_data
(
self
,
afs_path
,
local_path
,
fs_default_name
,
ugi
,
process_num
=
12
):
def
download_data
(
self
,
afs_path
,
local_path
,
fs_default_name
,
ugi
,
process_num
=
12
):
hadoop_home
=
"$HADOOP_HOME"
#hadoop_home = "$HADOOP_HOME"
hadoop_home
=
"~/tools/hadoop-xingtian/hadoop/"
configs
=
{
configs
=
{
"fs.default.name"
:
fs_default_name
,
"fs.default.name"
:
fs_default_name
,
...
@@ -169,8 +169,11 @@ class AsyncExecutor(object):
...
@@ -169,8 +169,11 @@ class AsyncExecutor(object):
self
.
instance
.
get_worker_index
(),
self
.
instance
.
get_worker_index
(),
self
.
instance
.
get_node_cnt
()
/
2
,
self
.
instance
.
get_node_cnt
()
/
2
,
multi_processes
=
process_num
)
multi_processes
=
process_num
)
self
.
instance
.
barrier_all
()
#wait for download_data #TODO only barriere worker
def
config_distributed_nodes
(
self
,
dist_opt
):
def
config_distributed_nodes
(
self
):
self
.
instance
=
ps_instance
.
PaddlePSInstance
(
1
,
2
)
return
self
.
instance
# get total rank
# get total rank
# get rank index
# get rank index
...
@@ -196,11 +199,15 @@ class AsyncExecutor(object):
...
@@ -196,11 +199,15 @@ class AsyncExecutor(object):
self
.
executor
.
gather_servers
(
ips
,
self
.
instance
.
get_node_cnt
())
self
.
executor
.
gather_servers
(
ips
,
self
.
instance
.
get_node_cnt
())
self
.
instance
.
barrier_all
()
#wait all worker start
self
.
instance
.
barrier_all
()
#wait all worker start
self
.
instance
.
barrier_all
()
#wait init model
self
.
instance
.
barrier_all
()
#wait init model
self
.
instance
.
barrier_all
()
#wait for download_data
self
.
instance
.
barrier_all
()
#wait for download_data
#TODO remove this after only barrier worker
self
.
instance
.
barrier_all
()
#wait worker do all things
self
.
instance
.
barrier_all
()
#wait worker do all things
self
.
instance
.
barrier_all
()
#sync
self
.
instance
.
barrier_all
()
#sync
def
init_worker
(
self
,
dist_desc
,
afs_path
,
local_path
,
fs_default_name
,
ugi
):
def
init_worker
(
self
,
dist_desc
,
startup_program
):
place
=
core
.
CPUPlace
()
executor
=
Executor
(
place
)
executor
.
run
(
startup_program
)
self
.
instance
.
barrier_all
()
#wait all server start
self
.
instance
.
barrier_all
()
#wait all server start
ips
=
self
.
instance
.
gather_ips
()
ips
=
self
.
instance
.
gather_ips
()
self
.
executor
.
init_worker
(
dist_desc
,
ips
,
self
.
instance
.
get_node_cnt
(),
self
.
instance
.
_rankid
)
self
.
executor
.
init_worker
(
dist_desc
,
ips
,
self
.
instance
.
get_node_cnt
(),
self
.
instance
.
_rankid
)
...
@@ -208,8 +215,6 @@ class AsyncExecutor(object):
...
@@ -208,8 +215,6 @@ class AsyncExecutor(object):
if
self
.
instance
.
is_first_worker
():
if
self
.
instance
.
is_first_worker
():
self
.
executor
.
init_model
()
self
.
executor
.
init_model
()
self
.
instance
.
barrier_all
()
#wait init model
self
.
instance
.
barrier_all
()
#wait init model
self
.
download_data
(
afs_path
,
local_path
,
fs_default_name
,
ugi
,
process_num
=
12
)
self
.
instance
.
barrier_all
()
#wait for download_data
def
init_model
(
self
):
def
init_model
(
self
):
self
.
executor
.
init_model
()
self
.
executor
.
init_model
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录