Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
83cd1859
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
83cd1859
编写于
8月 21, 2020
作者:
D
Dong Daxiang
提交者:
GitHub
8月 21, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【paddle.fleet】Meta from optimizer (#26392)
* consider the combination of different strategies to work together
上级
bdac6bc8
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
166 addition
and
14 deletion
+166
-14
python/paddle/distributed/fleet/base/fleet_base.py
python/paddle/distributed/fleet/base/fleet_base.py
+1
-0
python/paddle/distributed/fleet/base/strategy_compiler.py
python/paddle/distributed/fleet/base/strategy_compiler.py
+14
-0
python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
...paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
+6
-1
python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py
...paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py
+1
-0
python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py
...ributed/fleet/meta_optimizers/gradient_merge_optimizer.py
+8
-1
python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py
...ibuted/fleet/meta_optimizers/graph_execution_optimizer.py
+1
-0
python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py
...addle/distributed/fleet/meta_optimizers/lamb_optimizer.py
+2
-1
python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py
...addle/distributed/fleet/meta_optimizers/lars_optimizer.py
+2
-1
python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py
...e/distributed/fleet/meta_optimizers/localsgd_optimizer.py
+1
-0
python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py
.../distributed/fleet/meta_optimizers/meta_optimizer_base.py
+35
-4
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
...e/distributed/fleet/meta_optimizers/pipeline_optimizer.py
+1
-0
python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py
.../distributed/fleet/meta_optimizers/recompute_optimizer.py
+7
-1
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+2
-0
python/paddle/fluid/tests/unittests/launch_function_helper.py
...on/paddle/fluid/tests/unittests/launch_function_helper.py
+15
-0
python/paddle/fluid/tests/unittests/test_fleet_graph_execution_meta_optimizer.py
...ts/unittests/test_fleet_graph_execution_meta_optimizer.py
+12
-5
python/paddle/fluid/tests/unittests/test_fleet_meta_optimizer_base.py
...e/fluid/tests/unittests/test_fleet_meta_optimizer_base.py
+58
-0
未找到文件。
python/paddle/distributed/fleet/base/fleet_base.py
浏览文件 @
83cd1859
...
...
@@ -78,6 +78,7 @@ class Fleet(object):
def
init
(
self
,
role_maker
):
self
.
_role_maker
=
role_maker
self
.
strategy_compiler
=
StrategyCompiler
()
return
None
def
is_first_worker
(
self
):
"""
...
...
python/paddle/distributed/fleet/base/strategy_compiler.py
浏览文件 @
83cd1859
...
...
@@ -114,4 +114,18 @@ class StrategyCompiler(StrategyCompilerBase):
0
]
return_graph
=
None
if
graph_optimizers
==
None
else
graph_optimizers
[
0
]
if
meta_optimizers
==
None
or
graph_optimizers
==
None
:
return
return_meta
,
return_graph
# do heuristic filter here, if any meta optimizer in graph optimizers is in
# any meta optimizers' black list, set return_graph to None
need_graph_opt
=
True
for
graph_opt
in
graph_optimizers
:
for
program_opt
in
meta_optimizers
:
if
graph_opt
.
__class__
.
__name__
in
program_opt
.
meta_optimizers_black_list
:
need_graph_opt
=
False
if
not
need_graph_opt
:
return_graph
=
None
return
return_meta
,
return_graph
python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
浏览文件 @
83cd1859
...
...
@@ -23,7 +23,12 @@ class AMPOptimizer(MetaOptimizerBase):
self
.
inner_opt
=
optimizer
self
.
amp_opt
=
None
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_white_list
=
[
"LarsOptimizer"
,
"LambOptimizer"
,
"RecomputeOptimizer"
,
"LocalSGDOptimizer"
,
"GradientMergeOptimizer"
,
"GraphExecutionOptimizer"
]
self
.
meta_optimizers_black_list
=
[
"DGCOptimizer"
]
def
_set_basic_info
(
self
,
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
):
...
...
python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py
浏览文件 @
83cd1859
...
...
@@ -25,6 +25,7 @@ class DGCOptimizer(MetaOptimizerBase):
self
.
dgc_opt
=
None
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_black_list
=
[]
def
_set_basic_info
(
self
,
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
):
...
...
python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py
浏览文件 @
83cd1859
...
...
@@ -16,13 +16,20 @@ from .meta_optimizer_base import MetaOptimizerBase
__all__
=
[
"GradientMergeOptimizer"
]
# amp + gradient merge + lamb
class
GradientMergeOptimizer
(
MetaOptimizerBase
):
def
__init__
(
self
,
optimizer
):
super
(
GradientMergeOptimizer
,
self
).
__init__
(
optimizer
)
self
.
inner_opt
=
optimizer
self
.
wrapped_opt
=
GM
(
optimizer
)
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_white_list
=
[
"LarsOptimizer"
,
"LambOptimizer"
,
"GraphExecutionOptimizer"
,
]
self
.
meta_optimizers_black_list
=
[]
def
_set_basic_info
(
self
,
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
):
...
...
python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py
浏览文件 @
83cd1859
...
...
@@ -25,6 +25,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
self
.
inner_opt
=
optimizer
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_black_list
=
[]
def
_is_graph_out
(
self
):
return
True
...
...
python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py
浏览文件 @
83cd1859
...
...
@@ -25,7 +25,8 @@ class LambOptimizer(MetaOptimizerBase):
self
.
inner_opt
=
optimizer
self
.
lamb_opt
=
None
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_white_list
=
[
"GraphExecutionOptimizer"
]
self
.
meta_optimizers_black_list
=
[]
def
_set_basic_info
(
self
,
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
):
...
...
python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py
浏览文件 @
83cd1859
...
...
@@ -24,7 +24,8 @@ class LarsOptimizer(MetaOptimizerBase):
self
.
inner_opt
=
optimizer
self
.
lars_opt
=
None
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_white_list
=
[
"GraphExecutionOptimizer"
]
self
.
meta_optimizers_black_list
=
[]
def
_set_basic_info
(
self
,
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
):
...
...
python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py
浏览文件 @
83cd1859
...
...
@@ -25,6 +25,7 @@ class LocalSGDOptimizer(MetaOptimizerBase):
super
(
LocalSGDOptimizer
,
self
).
__init__
(
optimizer
)
self
.
inner_opt
=
optimizer
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_black_list
=
[
"GraphExecutionOptimizer"
]
self
.
snapshot_key
=
'@SNAPSHOT'
def
_can_apply
(
self
):
...
...
python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py
浏览文件 @
83cd1859
...
...
@@ -14,10 +14,16 @@
__all__
=
[
"MetaOptimizerBase"
]
from
paddle.fluid.optimizer
import
Optimizer
class
MetaOptimizerBase
(
object
):
class
MetaOptimizerBase
(
Optimizer
):
def
__init__
(
self
,
optimizer
):
pass
self
.
inner_opt
=
optimizer
self
.
_learning_rate
=
self
.
inner_opt
.
_learning_rate
self
.
_learning_rate_map
=
self
.
inner_opt
.
_learning_rate_map
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_black_list
=
[]
def
_set_basic_info
(
self
,
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
):
...
...
@@ -26,7 +32,7 @@ class MetaOptimizerBase(object):
self
.
user_defined_optimizer
=
user_defined_optimizer
self
.
user_defined_strategy
=
user_defined_strategy
def
_update_inner_optimier
(
self
,
optimizer
):
def
_update_inner_optimi
z
er
(
self
,
optimizer
):
self
.
inner_opt
=
optimizer
def
_can_apply
(
self
):
...
...
@@ -44,12 +50,37 @@ class MetaOptimizerBase(object):
raise
NotImplementedError
(
"you should implement disable strategy in {}"
.
format
(
type
(
self
).
__name__
))
def
apply_gradients
(
self
,
params_grads
):
return
self
.
inner_opt
.
apply_gradients
(
params_grads
=
params_grads
)
def
backward
(
self
,
loss
,
startup_program
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
,
callbacks
=
None
):
return
self
.
inner_opt
.
backward
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
,
callbacks
)
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
return
self
.
inner_opt
.
apply_optimize
(
loss
,
startup_program
=
startup_program
,
params_grads
=
params_grads
)
def
minimize_impl
(
self
,
loss
,
startup_program
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
):
raise
NotImplementedError
(
"meta optimizer not implemented"
)
params_grads
=
self
.
backward
(
loss
,
startup_program
=
startup_program
,
parameter_list
=
parameter_list
,
no_grad_set
=
no_grad_set
)
optimize_ops
=
self
.
apply_optimize
(
loss
,
startup_program
=
startup_program
,
params_grads
=
params_grads
)
return
optimize_ops
,
params_grads
def
minimize
(
self
,
loss
,
...
...
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
浏览文件 @
83cd1859
...
...
@@ -95,6 +95,7 @@ class PipelineOptimizer(MetaOptimizerBase):
self
.
inner_opt
=
optimizer
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_black_list
=
[]
def
_set_basic_info
(
self
,
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
):
...
...
python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py
浏览文件 @
83cd1859
...
...
@@ -24,7 +24,13 @@ class RecomputeOptimizer(MetaOptimizerBase):
self
.
inner_opt
=
optimizer
self
.
wrapped_opt
=
RO
(
optimizer
)
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_white_list
=
[
"LarsOptimizer"
,
"LambOptimizer"
,
"GradientMergeOptimizer"
,
"GraphExecutionOptimizer"
,
]
self
.
meta_optimizers_black_list
=
[]
def
_set_basic_info
(
self
,
loss
,
role_maker
,
user_defined_optimizer
,
user_defined_strategy
):
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
83cd1859
...
...
@@ -46,6 +46,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_lamb_meta_optimizer)
list
(
APPEND MIXED_DIST_TEST_OPS test_fleet_dgc_meta_optimizer
)
list
(
APPEND MIXED_DIST_TEST_OPS test_fleet_private_function
)
list
(
APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor
)
list
(
APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base
)
foreach
(
TEST_OP
${
MIXED_DIST_TEST_OPS
}
)
list
(
REMOVE_ITEM TEST_OPS
${
TEST_OP
}
)
endforeach
()
...
...
@@ -399,6 +400,7 @@ if(WITH_DISTRIBUTE)
py_test_modules
(
test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_fleet_private_function MODULES test_fleet_private_function ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS
${
dist_ENVS
}
)
if
(
NOT WIN32
)
py_test_modules
(
test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS
${
dist_ENVS
}
)
...
...
python/paddle/fluid/tests/unittests/launch_function_helper.py
浏览文件 @
83cd1859
...
...
@@ -13,6 +13,8 @@
# limitations under the License.
from
multiprocessing
import
Pool
,
Process
import
os
import
socket
from
contextlib
import
closing
def
launch_func
(
func
,
env_dict
):
...
...
@@ -20,3 +22,16 @@ def launch_func(func, env_dict):
os
.
environ
[
key
]
=
env_dict
[
key
]
proc
=
Process
(
target
=
func
)
return
proc
def
_find_free_port
(
port_set
):
def
__free_port
():
with
closing
(
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
))
as
s
:
s
.
bind
((
''
,
0
))
return
s
.
getsockname
()[
1
]
while
True
:
port
=
__free_port
()
if
port
not
in
port_set
:
port_set
.
add
(
port
)
return
port
python/paddle/fluid/tests/unittests/test_fleet_graph_execution_meta_optimizer.py
浏览文件 @
83cd1859
...
...
@@ -15,7 +15,7 @@
import
unittest
import
paddle
import
os
from
launch_function_helper
import
launch_func
from
launch_function_helper
import
launch_func
,
_find_free_port
class
TestFleetGraphExecutionMetaOptimizer
(
unittest
.
TestCase
):
...
...
@@ -71,20 +71,27 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
proc_b
.
join
()
def
test_graph_execution_optimizer
(
self
):
port_set
=
set
()
port_a
=
_find_free_port
(
port_set
)
port_b
=
_find_free_port
(
port_set
)
node_a
=
{
"PADDLE_TRAINER_ID"
:
"0"
,
"PADDLE_CURRENT_ENDPOINT"
:
"127.0.0.1:
36001"
,
"PADDLE_CURRENT_ENDPOINT"
:
"127.0.0.1:
{}"
.
format
(
port_a
)
,
"PADDLE_TRAINERS_NUM"
:
"2"
,
"PADDLE_TRAINER_ENDPOINTS"
:
"127.0.0.1:36001,127.0.0.1:36002"
,
"PADDLE_TRAINER_ENDPOINTS"
:
"127.0.0.1:{},127.0.0.1:{}"
.
format
(
port_a
,
port_b
),
"http_proxy"
:
""
,
"https_proxy"
:
""
}
node_b
=
{
"PADDLE_TRAINER_ID"
:
"1"
,
"PADDLE_CURRENT_ENDPOINT"
:
"127.0.0.1:
36002"
,
"PADDLE_CURRENT_ENDPOINT"
:
"127.0.0.1:
{}"
.
format
(
port_b
)
,
"PADDLE_TRAINERS_NUM"
:
"2"
,
"PADDLE_TRAINER_ENDPOINTS"
:
"127.0.0.1:36001,127.0.0.1:36002"
,
"PADDLE_TRAINER_ENDPOINTS"
:
"127.0.0.1:{},127.0.0.1:{}"
.
format
(
port_a
,
port_b
),
"http_proxy"
:
""
,
"https_proxy"
:
""
}
...
...
python/paddle/fluid/tests/unittests/test_fleet_meta_optimizer_base.py
0 → 100755
浏览文件 @
83cd1859
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
from
paddle
import
fluid
import
os
import
paddle.distributed.fleet
as
fleet
import
paddle.fluid.incubate.fleet.base.role_maker
as
role_maker
from
paddle.distributed.fleet.meta_optimizers.meta_optimizer_base
import
MetaOptimizerBase
class
TestFleetMetaOptimizerBase
(
unittest
.
TestCase
):
def
net
(
main_prog
,
startup_prog
):
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
fleet
.
init
(
role
)
input_x
=
paddle
.
fluid
.
layers
.
data
(
name
=
"x"
,
shape
=
[
32
],
dtype
=
'float32'
)
input_y
=
paddle
.
fluid
.
layers
.
data
(
name
=
"y"
,
shape
=
[
1
],
dtype
=
'int64'
)
fc_1
=
paddle
.
fluid
.
layers
.
fc
(
input
=
input_x
,
size
=
64
,
act
=
'tanh'
)
fc_2
=
paddle
.
fluid
.
layers
.
fc
(
input
=
fc_1
,
size
=
256
,
act
=
'tanh'
)
prediction
=
paddle
.
fluid
.
layers
.
fc
(
input
=
[
fc_2
],
size
=
2
,
act
=
'softmax'
)
cost
=
paddle
.
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
input_y
)
avg_cost
=
paddle
.
fluid
.
layers
.
mean
(
x
=
cost
)
optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.01
)
opt
=
MetaOptimizerBase
(
optimizer
)
opt_ops
,
params_grads
=
opt
.
minimize
(
avg_cost
)
opt
.
apply_optimize
(
avg_cost
,
paddle
.
static
.
default_startup_program
(),
params_grads
)
return
None
net
(
fluid
.
default_startup_program
(),
fluid
.
default_main_program
())
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录