Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8a335b50
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8a335b50
编写于
1月 29, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add downpour device_worker pb configuration
上级
24a80011
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
68 addition
and
3 deletion
+68
-3
paddle/fluid/framework/trainer_desc.proto
paddle/fluid/framework/trainer_desc.proto
+0
-1
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+37
-0
python/paddle/fluid/trainer_desc.py
python/paddle/fluid/trainer_desc.py
+31
-2
未找到文件。
paddle/fluid/framework/trainer_desc.proto
浏览文件 @
8a335b50
...
...
@@ -59,7 +59,6 @@ message TableParameter {
optional
int64
table_id
=
1
;
repeated
string
dense_value_name
=
2
;
repeated
string
dense_grad_name
=
3
;
repeated
int32
dense_table_size
=
4
;
repeated
int32
push_dense_wait_times
=
5
;
// sparse table only
repeated
string
sparse_key_name
=
6
;
...
...
python/paddle/fluid/async_executor.py
浏览文件 @
8a335b50
...
...
@@ -24,6 +24,7 @@ from paddle.fluid.proto import data_feed_pb2
from
google.protobuf
import
text_format
from
.
import
io
from
.data_feed_desc
import
DataFeedDesc
from
.trainer_desc
import
TrainerDesc
,
MultiTrainer
,
DistMultiTrainer
from
.distributed
import
ps_instance
from
.contrib.utils
import
hdfs_utils
as
hdfs
...
...
@@ -89,6 +90,38 @@ class AsyncExecutor(object):
self
.
executor
=
core
.
AsyncExecutor
(
scope
,
p
)
self
.
instance
=
None
def
run
(
self
,
program
,
data_feed
,
filelist
,
thread_num
,
fetch
,
debug
=
False
):
if
program
is
None
:
program
=
default_main_program
()
program_desc
=
program
.
desc
if
data_feed
is
None
:
raise
ValueError
(
'ValueError: data_feed should be provided'
)
if
filelist
is
None
:
raise
ValueError
(
'ValueError: filelist should be provided'
)
if
isinstance
(
filelist
,
str
):
filelist
=
[
filelist
]
if
not
isinstance
(
thread_num
,
int
):
raise
TypeError
(
'TypeError: thread_num should be a positive number'
)
is_local
=
self
.
instance
==
None
trainer
=
None
if
is_local
:
trainer
=
MultiTrainer
(
data_feed
=
data_feed
,
worker
=
"Hogwild"
)
else
:
trainer
=
DistMultiTrainer
(
data_feed
,
worker
=
"Downpour"
,
fleet_desc
=
self
.
dist_desc
)
# define a trainer and a device_worker here
trainer
.
set_thread
(
thread_num
)
trainer
.
set_filelist
(
filelist
)
trainer
.
set_data_feed
(
data_feed
)
self
.
executor
.
run_from_files
(
program_desc
,
trainer
.
_desc
(),
debug
)
'''
def run(self,
program,
data_feed,
...
...
@@ -160,6 +193,7 @@ class AsyncExecutor(object):
self.executor.run_from_files(program_desc,
data_feed.desc(), filelist, thread_num,
fetch_var_names, mode, debug)
'''
def
download_data
(
self
,
afs_path
,
...
...
@@ -250,6 +284,7 @@ class AsyncExecutor(object):
raise
ValueError
(
'instance is None, please run config_distributed_nodes init instance'
)
self
.
init_desc
=
init_desc
self
.
executor
.
init_server
(
dist_desc
,
self
.
instance
.
_rankid
)
ip
=
self
.
executor
.
start_server
()
self
.
instance
.
set_ip
(
ip
)
...
...
@@ -270,6 +305,8 @@ class AsyncExecutor(object):
raise
ValueError
(
'instance is None, please run config_distributed_nodes init instance'
)
self
.
dist_desc
=
dist_desc
place
=
core
.
CPUPlace
()
executor
=
Executor
(
place
)
executor
.
run
(
startup_program
)
...
...
python/paddle/fluid/trainer_desc.py
浏览文件 @
8a335b50
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
from
paddle.fluid.proto
import
trainer_desc_pb2
import
ps_pb2
as
pslib
from
google.protobuf
import
text_format
__all__
=
[
'TrainerDesc'
,
'MultiTrainer'
,
'DistMultiTrainer'
]
...
...
@@ -42,7 +43,7 @@ class TrainerDesc(object):
class
MultiTrainer
(
TrainerDesc
):
def
__init__
(
self
,
worker
=
"Hogwild"
):
def
__init__
(
self
,
dataset
=
None
,
worker
=
"Hogwild"
):
super
(
MultiTrainer
,
self
).
__init__
()
if
worker
==
"Hogwild"
:
self
.
proto_desc
.
device_worker_name
=
worker
+
"Worker"
...
...
@@ -53,11 +54,39 @@ class MultiTrainer(TrainerDesc):
class
DistMultiTrainer
(
TrainerDesc
):
def
__init__
(
self
,
worker
=
'Downpour'
):
def
__init__
(
self
,
dataset
=
None
,
worker
=
'Downpour'
,
fleet_desc
=
None
):
super
(
DistMultiTrainer
,
self
).
__init__
()
if
worker
==
"Downpour"
:
self
.
proto_desc
.
device_worker_name
=
worker
+
"Worker"
self
.
proto_desc
.
class_name
=
"DistMultiTrainer"
self
.
proto_desc
.
data_feed
.
CopyFrom
(
dataset
)
downpour
=
self
.
proto_desc
.
downpour_param
.
add
()
# sparse table should specify:
sparse_table
=
downpour
.
sparse_table
.
add
()
sparse_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
sparse_table
.
table_id
sparse_table
.
sparse_key_name
.
CopyFrom
(
fleet_desc
.
trainer_param
()
.
sparse_table
().
slot_key
())
sparse_table
.
sparse_value_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
sparse_table
().
slot_value
())
sparse_table
.
sparse_grad_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
sparse_table
().
slot_gradient
())
sparse_table
.
emb_dim
=
fleet_desc
.
server_param
.
downpour_server_param
.
downpour_table_param
.
accessor
.
fea_dim
-
2
sparse_table
.
fea_dim
=
downpour
.
emb_dim
+
2
sparse_table
.
label_var_name
=
"click"
# dense table should specify:
dense_table
=
downpour
.
dense_table
.
add
()
dense_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
dense_table
.
table_id
# dense_value_name
dense_table
.
dense_value_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
dense_table
().
dense_variable_name
)
# dense_grad_name
dense_table
.
dense_grad_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
dense_table
().
dense_gradient_name
)
downpour
.
skipped_ops
.
extend
(
fleet_desc
.
trainer_param
.
skip_op
)
print
(
str
(
self
.
proto_desc
))
else
:
raise
ValueError
(
'ValueError: DeviceWorker %s '
'is not supported in DistMultiTrainer'
%
worker
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录