Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2ae32f0b
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2ae32f0b
编写于
8月 09, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
revert the change of api
上级
1b690213
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
10 addition
and
13 deletion
+10
-13
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+9
-11
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+1
-2
未找到文件。
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
浏览文件 @
2ae32f0b
...
@@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase):
...
@@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase):
self
.
origin_prog
=
main
.
clone
()
self
.
origin_prog
=
main
.
clone
()
return
main
return
main
def
get_trainer
(
self
,
config
=
None
):
def
get_trainer
(
self
,
config
=
None
,
sync_mode
=
True
):
t
=
self
.
_transpiler_instance
(
config
)
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
)
return
t
.
get_trainer_program
()
return
t
.
get_trainer_program
()
def
get_pserver
(
self
,
ep
,
config
=
None
):
def
get_pserver
(
self
,
ep
,
config
=
None
,
sync_mode
=
True
):
t
=
self
.
_transpiler_instance
(
config
)
t
=
self
.
_transpiler_instance
(
config
,
sync_mode
)
pserver
=
t
.
get_pserver_program
(
ep
)
pserver
=
t
.
get_pserver_program
(
ep
)
startup
=
t
.
get_startup_program
(
ep
,
pserver
)
startup
=
t
.
get_startup_program
(
ep
,
pserver
)
return
pserver
,
startup
return
pserver
,
startup
def
_transpiler_instance
(
self
,
config
=
None
):
def
_transpiler_instance
(
self
,
config
=
None
,
sync_mode
=
True
):
if
not
self
.
transpiler
:
if
not
self
.
transpiler
:
main
=
self
.
get_main_program
()
main
=
self
.
get_main_program
()
self
.
transpiler
=
fluid
.
DistributeTranspiler
(
config
=
config
)
self
.
transpiler
=
fluid
.
DistributeTranspiler
(
config
=
config
)
...
@@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase):
...
@@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase):
self
.
trainer_id
,
self
.
trainer_id
,
program
=
main
,
program
=
main
,
pservers
=
self
.
pserver_eps
,
pservers
=
self
.
pserver_eps
,
trainers
=
self
.
trainers
)
trainers
=
self
.
trainers
,
sync_mode
=
sync_mode
)
return
self
.
transpiler
return
self
.
transpiler
...
@@ -470,8 +471,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
...
@@ -470,8 +471,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
def
transpiler_test_impl
(
self
):
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
sync_mode
=
False
pserver1
,
startup1
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
,
False
)
pserver1
,
startup1
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
)
self
.
assertEqual
(
len
(
pserver1
.
blocks
),
3
)
self
.
assertEqual
(
len
(
pserver1
.
blocks
),
3
)
# 0 listen_and_serv
# 0 listen_and_serv
...
@@ -503,9 +503,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
...
@@ -503,9 +503,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
def
transpiler_test_impl
(
self
):
def
transpiler_test_impl
(
self
):
config
=
fluid
.
DistributeTranspilerConfig
()
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
sync_mode
=
False
pserver1
,
startup1
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
)
pserver1
,
startup1
=
self
.
get_pserver
(
self
.
pserver1_ep
,
config
,
False
)
self
.
assertEqual
(
len
(
pserver1
.
blocks
),
6
)
self
.
assertEqual
(
len
(
pserver1
.
blocks
),
6
)
# 0 listen_and_serv
# 0 listen_and_serv
...
@@ -525,7 +524,6 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
...
@@ -525,7 +524,6 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
trainer
=
self
.
get_trainer
(
config
)
trainer
=
self
.
get_trainer
(
config
)
self
.
assertEqual
(
len
(
trainer
.
blocks
),
1
)
self
.
assertEqual
(
len
(
trainer
.
blocks
),
1
)
print
([
op
.
type
for
op
in
trainer
.
blocks
[
0
].
ops
])
ops
=
[
ops
=
[
'split_ids'
,
'prefetch'
,
'merge_ids'
,
'sequence_pool'
,
'split_ids'
,
'split_ids'
,
'prefetch'
,
'merge_ids'
,
'sequence_pool'
,
'split_ids'
,
'prefetch'
,
'merge_ids'
,
'sequence_pool'
,
'concat'
,
'mul'
,
'prefetch'
,
'merge_ids'
,
'sequence_pool'
,
'concat'
,
'mul'
,
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
2ae32f0b
...
@@ -124,7 +124,6 @@ class DistributeTranspilerConfig(object):
...
@@ -124,7 +124,6 @@ class DistributeTranspilerConfig(object):
slice_var_up
=
True
slice_var_up
=
True
split_method
=
None
split_method
=
None
min_block_size
=
8192
min_block_size
=
8192
sync_mode
=
True
class
DistributeTranspiler
(
object
):
class
DistributeTranspiler
(
object
):
...
@@ -198,7 +197,7 @@ class DistributeTranspiler(object):
...
@@ -198,7 +197,7 @@ class DistributeTranspiler(object):
program
=
default_main_program
()
program
=
default_main_program
()
self
.
origin_program
=
program
self
.
origin_program
=
program
self
.
trainer_num
=
trainers
self
.
trainer_num
=
trainers
self
.
sync_mode
=
sync_mode
and
self
.
config
.
sync_mode
self
.
sync_mode
=
sync_mode
self
.
trainer_id
=
trainer_id
self
.
trainer_id
=
trainer_id
pserver_endpoints
=
pservers
.
split
(
","
)
pserver_endpoints
=
pservers
.
split
(
","
)
self
.
pserver_endpoints
=
pserver_endpoints
self
.
pserver_endpoints
=
pserver_endpoints
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录