Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4adac0e3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4adac0e3
编写于
8月 05, 2020
作者:
D
Dong Daxiang
提交者:
GitHub
8月 05, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【paddle.fleet】Add fleet base context (#25954)
* generate context during compile
上级
358bc06c
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
27 addition
and
20 deletion
+27
-20
python/paddle/fleet/base/fleet_base.py
python/paddle/fleet/base/fleet_base.py
+15
-4
python/paddle/fleet/base/runtime_factory.py
python/paddle/fleet/base/runtime_factory.py
+3
-5
python/paddle/fleet/base/util_factory.py
python/paddle/fleet/base/util_factory.py
+3
-4
python/paddle/fleet/runtime/runtime_base.py
python/paddle/fleet/runtime/runtime_base.py
+2
-5
python/paddle/fluid/tests/unittests/test_fleet_util.py
python/paddle/fluid/tests/unittests/test_fleet_util.py
+4
-2
未找到文件。
python/paddle/fleet/base/fleet_base.py
浏览文件 @
4adac0e3
...
...
@@ -279,8 +279,11 @@ class Fleet(object):
# for more examples, please reference https://github.com/PaddlePaddle/Fleet
"""
context
=
{}
# cache original feed forward program
self
.
origin_main_program
=
loss
.
block
.
program
context
[
"origin_main_program"
]
=
self
.
origin_main_program
context
[
"loss"
]
=
loss
if
startup_program
==
None
:
self
.
origin_startup_program
=
\
paddle
.
default_startup_program
().
clone
(
for_test
=
False
)
...
...
@@ -288,6 +291,8 @@ class Fleet(object):
else
:
self
.
origin_startup_program
=
\
startup_program
.
clone
(
for_test
=
False
)
context
[
"origin_startup_program"
]
=
startup_program
context
[
"role_maker"
]
=
self
.
_role_maker
# compile time
distributed_optimizer_list
=
\
...
...
@@ -317,6 +322,9 @@ class Fleet(object):
valid_strategy
=
self
.
strategy_compiler
.
_get_valid_strategy
(
self
.
user_defined_strategy
,
can_not_apply_optimizer_list
)
context
[
"valid_strategy"
]
=
valid_strategy
self
.
valid_strategy
=
valid_strategy
optimize_ops
=
[]
...
...
@@ -334,6 +342,8 @@ class Fleet(object):
parameter_list
=
parameter_list
,
no_grad_set
=
no_grad_set
)
context
[
"program_optimize_ops"
]
=
optimize_ops
context
[
"program_params_grads"
]
=
params_grads
if
graph_optimizer
:
optimize_ops
,
params_grads
=
graph_optimizer
.
minimize
(
loss
,
...
...
@@ -344,12 +354,13 @@ class Fleet(object):
# if a graph optimizer takes effect, mostly
# optimizers_ops and params_grads are None
# i.e. users can not modify current computation graph anymore
context
[
"graph_optimize_ops"
]
=
optimize_ops
context
[
"graph_optimize_grads"
]
=
params_grads
if
self
.
_runtime_handle
is
None
:
self
.
_runtime_handle
=
RuntimeFactory
().
_create_runtime
(
valid_strategy
,
self
.
_role_maker
,
optimize_ops
,
params_grads
)
self
.
_runtime_handle
=
RuntimeFactory
().
_create_runtime
(
context
)
if
self
.
_util
is
None
:
self
.
_util
=
UtilFactory
().
_create_util
(
valid_strategy
,
self
.
_role_maker
,
optimize_ops
,
params_grads
)
self
.
_util
=
UtilFactory
().
_create_util
(
context
)
return
optimize_ops
,
params_grads
python/paddle/fleet/base/runtime_factory.py
浏览文件 @
4adac0e3
...
...
@@ -18,10 +18,8 @@ class RuntimeFactory(object):
def
__init__
(
self
):
pass
def
_create_runtime
(
self
,
final_dist_strategy
,
role_maker
,
opt_ops
,
params_grads
):
if
role_maker
.
_is_collective
:
def
_create_runtime
(
self
,
context
):
if
context
[
"role_maker"
].
_is_collective
:
collective_runtime
=
CollectiveRuntime
()
collective_runtime
.
_set_basic_info
(
final_dist_strategy
,
role_maker
,
opt_ops
,
params_grads
)
collective_runtime
.
_set_basic_info
(
context
)
return
collective_runtime
python/paddle/fleet/base/util_factory.py
浏览文件 @
4adac0e3
...
...
@@ -20,11 +20,10 @@ __all__ = ['UtilBase']
class
UtilFactory
(
object
):
def
_create_util
(
self
,
dist_strategy
,
role_maker
,
optimize_ops
,
params_grads
):
def
_create_util
(
self
,
context
):
util
=
UtilBase
()
util
.
_set_strategy
(
dist_strategy
)
util
.
_set_role_maker
(
role_maker
)
util
.
_set_strategy
(
context
[
"valid_strategy"
]
)
util
.
_set_role_maker
(
context
[
"role_maker"
]
)
return
util
...
...
python/paddle/fleet/runtime/runtime_base.py
浏览文件 @
4adac0e3
...
...
@@ -19,11 +19,8 @@ class RuntimeBase(object):
def
__init__
(
self
):
pass
def
_set_basic_info
(
self
,
loss
,
role_maker
,
optimizer
,
strategy
):
self
.
loss
=
loss
self
.
role_maker
=
role_maker
self
.
optimizer
=
optimizer
self
.
strategy
=
strategy
def
_set_basic_info
(
self
,
context
):
self
.
context
=
context
def
_run_worker
(
self
):
pass
...
...
python/paddle/fluid/tests/unittests/test_fleet_util.py
浏览文件 @
4adac0e3
...
...
@@ -33,8 +33,10 @@ class TestFleetUtil(unittest.TestCase):
role_maker
=
None
# should be fleet.PaddleCloudRoleMaker()
optimize_ops
=
[]
params_grads
=
[]
util
=
factory
.
_create_util
(
strategy
,
role_maker
,
optimize_ops
,
params_grads
)
context
=
{}
context
[
"role_maker"
]
=
role_maker
context
[
"valid_strategy"
]
=
strategy
util
=
factory
.
_create_util
(
context
)
self
.
assertEqual
(
util
.
role_maker
,
None
)
def
test_get_util
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录