Commit 10ed9e0a
Authored Dec 11, 2018 by heqiaozhi
Parent: 57ac412b

download & run & instance
Showing 3 changed files with 39 additions and 25 deletions (+39 -25)

paddle/fluid/framework/async_executor.cc   +23 -15
paddle/fluid/framework/async_executor.h     +2  -1
python/paddle/fluid/async_executor.py      +14  -9
paddle/fluid/framework/async_executor.cc
@@ -191,18 +191,19 @@ void AsyncExecutor::SaveModel(const std::string& path) {
   }
 }
 
-void AsyncExecutor::PrepareDenseThread() {
-  DensePullThreadParam param;
-  param.ps_client = _pslib_ptr->_worker_ptr;;
-  param.threshold = 1;  //GlobalConfig::instance().pull_dense_per_batch; //TODO
-  param.training_thread_num = actual_thread_num;
-  param.root_scope = root_scope_;
-  //param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
-  param.dense_params = &_param_config.dense_variable_name;
-
-  _pull_dense_thread =
-      std::shared_ptr<DensePullThread>(new DensePullThread(param));
-  _pull_dense_thread->start();
+void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
+  if (mode == "mpi") {
+    DensePullThreadParam param;
+    param.ps_client = _pslib_ptr->_worker_ptr;;
+    param.threshold = 1;  //GlobalConfig::instance().pull_dense_per_batch; //TODO
+    param.training_thread_num = actual_thread_num;
+    param.root_scope = root_scope_;
+    //param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
+    param.dense_params = &_param_config.dense_variable_name;
+
+    _pull_dense_thread =
+        std::shared_ptr<DensePullThread>(new DensePullThread(param));
+    _pull_dense_thread->start();
+  }
 }
 
 void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
@@ -210,6 +211,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
                                 const std::vector<std::string>& filelist,
                                 const int thread_num,
                                 const std::vector<std::string>& fetch_var_names,
+                                const std::string& mode,
                                 const bool debug) {
   std::vector<std::thread> threads;
@@ -251,11 +253,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   // todo: should be factory method for creating datafeed
   std::vector<std::shared_ptr<DataFeed>> readers;
   PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
-  PrepareDenseThread();
+  PrepareDenseThread(mode);
 
   std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
   workers.resize(actual_thread_num);
   for (auto& worker : workers) {
-    worker.reset(new AsyncExecutorThreadWorker);
+    if (mode == "mpi") {
+      worker.reset(new AsyncExecutorThreadWorker);
+    } else {
+      worker.reset(new ExecutorThreadWorker);
+    }
   }
 
   // prepare thread resource here
@@ -274,7 +280,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   for (auto& th : threads) {
     th.join();
   }
-  _pull_dense_thread->stop();
+  if (mode == "mpi") {
+    _pull_dense_thread->stop();
+  }
   root_scope_->DropKids();
 
   return;
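Taken together, the async_executor.cc hunks above make the pslib-backed path opt-in: PrepareDenseThread only builds and starts the DensePullThread when mode == "mpi", RunFromFile only stops it in that mode, and each worker is an AsyncExecutorThreadWorker under "mpi" but a plain ExecutorThreadWorker otherwise. From Python this is driven by the new mode argument of AsyncExecutor.run() (see the python diff below). A minimal usage sketch, assuming train_program, data_feed, filelist, and fetch_vars are placeholders built elsewhere and that AsyncExecutor() can be constructed with its default arguments (none of this sketch is part of the commit):

# Usage sketch only, under the assumptions stated above.
from paddle.fluid.async_executor import AsyncExecutor

def train_local(train_program, data_feed, filelist, fetch_vars):
    exe = AsyncExecutor()
    # Default mode="": RunFromFile skips PrepareDenseThread and resets each
    # worker to a plain ExecutorThreadWorker.
    exe.run(train_program, data_feed, filelist, 10, fetch_vars)

def train_distributed(exe, train_program, data_feed, filelist, fetch_vars):
    # mode="mpi" enables the DensePullThread and AsyncExecutorThreadWorker;
    # it assumes the pslib setup (config_distributed_nodes, init_worker, ...)
    # shown in the python diff below has already been done on `exe`.
    exe.run(train_program, data_feed, filelist, 10, fetch_vars, mode="mpi")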
paddle/fluid/framework/async_executor.h
@@ -61,6 +61,7 @@ class AsyncExecutor {
                    const std::vector<std::string>& filelist,
                    const int thread_num,
                    const std::vector<std::string>& fetch_names,
+                   const std::string& mode,
                    const bool debug = false);
   //void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
   void InitServer(const std::string& dist_desc, int index);
@@ -79,7 +80,7 @@ class AsyncExecutor {
                         const std::vector<std::string>& fetch_var_names,
                         Scope* root_scope, const int thread_index,
                         const bool debug);
-  void PrepareDenseThread();
+  void PrepareDenseThread(const std::string& mode);
  public:
   std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
   std::shared_ptr<DensePullThread> _pull_dense_thread;
python/paddle/fluid/async_executor.py
@@ -87,9 +87,8 @@ class AsyncExecutor(object):
         scope = global_scope()
         self.executor = core.AsyncExecutor(scope, p)
-        self.instance = ps_instance.PaddlePSInstance(1, 2)
 
-    def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
+    def run(self, program, data_feed, filelist, thread_num, fetch, mode="", debug=False):
         """
         Run program by this AsyncExecutor. Training dataset will be in filelist.
         Users can also inspect certain variables by naming them in parameter
@@ -151,10 +150,11 @@ class AsyncExecutor(object):
         self.executor.run_from_files(program_desc,
                                      data_feed.desc(), filelist, thread_num,
-                                     fetch_var_names, debug)
+                                     fetch_var_names, mode, debug)
 
     def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12):
-        hadoop_home = "$HADOOP_HOME"
+        #hadoop_home = "$HADOOP_HOME"
+        hadoop_home = "~/tools/hadoop-xingtian/hadoop/"
 
         configs = {
             "fs.default.name": fs_default_name,
@@ -169,8 +169,11 @@ class AsyncExecutor(object):
                               self.instance.get_worker_index(),
                               self.instance.get_node_cnt() / 2,
                               multi_processes=process_num)
+        self.instance.barrier_all() #wait for download_data #TODO only barriere worker
 
-    def config_distributed_nodes(self, dist_opt):
+    def config_distributed_nodes(self):
+        self.instance = ps_instance.PaddlePSInstance(1, 2)
+        return self.instance
 
         # get total rank
         # get rank index
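The two hunks above also rework download_data and node configuration: hadoop_home is switched from the $HADOOP_HOME reference to a hard-coded ~/tools/hadoop-xingtian/hadoop/ path, a barrier is added after the multi-process download, and config_distributed_nodes() now creates and returns the PaddlePSInstance(1, 2) itself. A hedged call sketch for download_data; every value below is a placeholder for illustration and does not appear in the commit:

# Placeholder values only; the real afs_path, local_path, fs_default_name and
# ugi depend on the user's cluster, and hadoop_home is now hard-coded inside
# download_data, so other machines would need to adjust it there.
def fetch_training_files(exe):
    exe.download_data("/app/some/afs/dir",           # afs_path
                      "./local_training_data",       # local_path
                      "afs://example.afs.com:9902",  # fs_default_name
                      "username,password",           # ugi
                      process_num=12)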
@@ -196,11 +199,15 @@ class AsyncExecutor(object):
         self.executor.gather_servers(ips, self.instance.get_node_cnt())
         self.instance.barrier_all() #wait all worker start
-        self.instance.barrier_all() #wait init model
-        self.instance.barrier_all() #wait for download_data
+        self.instance.barrier_all() #wait for download_data
+        #TODO remove this after only barrier worker
+        self.instance.barrier_all() #wait worker do all things
         self.instance.barrier_all() #sync
 
-    def init_worker(self, dist_desc, afs_path, local_path, fs_default_name, ugi):
+    def init_worker(self, dist_desc, startup_program):
+        place = core.CPUPlace()
+        executor = Executor(place)
+        executor.run(startup_program)
+
         self.instance.barrier_all() #wait all server start
         ips = self.instance.gather_ips()
         self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(),
                                   self.instance._rankid)
@@ -208,8 +215,6 @@ class AsyncExecutor(object):
         if self.instance.is_first_worker():
             self.executor.init_model()
         self.instance.barrier_all() #wait init model
-        self.download_data(afs_path, local_path, fs_default_name, ugi, process_num=12)
-        self.instance.barrier_all() #wait for download_data
 
     def init_model(self):
         self.executor.init_model()
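Read as a whole, the python diff reshapes the distributed entry points: the PaddlePSInstance is no longer created in __init__ but in config_distributed_nodes(), init_worker now takes a startup_program (which it runs on CPUPlace) instead of AFS download arguments and no longer downloads data itself, data download becomes a separate step, and run() forwards the new mode string down to run_from_files. A hedged worker-side sketch of the resulting flow; dist_desc, startup_program, main_program, data_feed, filelist, fetch_vars and the cluster settings are placeholders, and the server-side setup is omitted:

# Worker-side sketch only, under the assumptions stated above.
from paddle.fluid.async_executor import AsyncExecutor

def run_distributed_worker(dist_desc, startup_program, main_program,
                           data_feed, filelist, fetch_vars):
    exe = AsyncExecutor()
    instance = exe.config_distributed_nodes()   # creates PaddlePSInstance(1, 2)

    # Runs startup_program on CPUPlace and initializes the pslib worker;
    # after this commit it no longer downloads data itself.
    exe.init_worker(dist_desc, startup_program)

    # Explicit data download step (placeholder cluster settings).
    exe.download_data("/app/some/afs/dir", "./local_training_data",
                      "afs://example.afs.com:9902", "username,password",
                      process_num=12)

    # Train with the pslib path enabled via mode="mpi".
    exe.run(main_program, data_feed, filelist, 10, fetch_vars, mode="mpi")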