Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e6aacd1e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2310
Star
20933
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e6aacd1e
编写于
7月 30, 2021
作者:
W
wangguanqun
提交者:
GitHub
7月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add trainer desc config to distributed strategy (#34457)
* add trainer desc config to distributed strategy * code style modified
上级
41c4f723
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
136 addition
and
0 deletion
+136
-0
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+8
-0
python/paddle/distributed/fleet/base/distributed_strategy.py
python/paddle/distributed/fleet/base/distributed_strategy.py
+39
-0
python/paddle/distributed/fleet/base/fleet_base.py
python/paddle/distributed/fleet/base/fleet_base.py
+8
-0
python/paddle/fluid/tests/unittests/test_dist_fleet_trainer_desc_config.py
...id/tests/unittests/test_dist_fleet_trainer_desc_config.py
+68
-0
python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
.../fluid/tests/unittests/test_fleet_distributed_strategy.py
+13
-0
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
e6aacd1e
...
...
@@ -146,6 +146,13 @@ message AsyncConfig {
optional
int32
use_ps_gpu
=
12
[
default
=
0
];
}
message
TrainerDescConfig
{
optional
string
dump_fields_path
=
1
;
repeated
string
dump_fields
=
2
;
repeated
string
dump_param
=
3
;
repeated
string
stat_var_names
=
4
;
}
message
PipelineConfig
{
optional
int32
micro_batch_size
=
1
[
default
=
1
];
optional
int32
accumulate_steps
=
2
[
default
=
1
];
...
...
@@ -206,6 +213,7 @@ message DistributedStrategy {
optional
ShardingConfig
sharding_configs
=
111
;
optional
HybridConfig
hybrid_configs
=
112
;
optional
TensorParallelConfig
tensor_parallel_configs
=
113
;
optional
TrainerDescConfig
trainer_desc_configs
=
114
;
optional
BuildStrategy
build_strategy
=
201
;
optional
ExecutionStrategy
execution_strategy
=
202
;
optional
GradientScaleConfig
gradient_scale_configs
=
203
;
...
...
python/paddle/distributed/fleet/base/distributed_strategy.py
浏览文件 @
e6aacd1e
...
...
@@ -360,6 +360,45 @@ class DistributedStrategy(object):
"a_sync_configs"
)
assign_configs_value
(
self
.
strategy
.
a_sync_configs
,
configs
)
@
property
def
trainer_desc_configs
(
self
):
"""
Set trainer desc configurations.
**Notes**:
dump_fields_path(str): the path of dump fields
dump_fields(list(str)): the fields that you want to dump
dump_param(list(str)): the param that you want to dump
stat_var_names(list(str)):
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
role_maker = fleet.PaddleCloudRoleMaker()
fleet.init(role_maker)
strategy = fleet.DistributedStrategy()
configs = {"dump_fields_path": "./dump_data", "dump_fields": ["xxx", "yyy"]}
strategy.trainer_desc_configs = configs
# code block for defining loss and local optimizer
# sgd = fleet.distributed_optimizer(optimizer, strategy)
"""
return
get_msg_dict
(
self
.
strategy
.
trainer_desc_configs
)
@
trainer_desc_configs
.
setter
@
is_strict_auto
def
trainer_desc_configs
(
self
,
configs
):
check_configs_key
(
self
.
strategy
.
trainer_desc_configs
,
configs
,
"trainer_desc_configs"
)
assign_configs_value
(
self
.
strategy
.
trainer_desc_configs
,
configs
)
@
property
def
amp
(
self
):
"""
...
...
python/paddle/distributed/fleet/base/fleet_base.py
浏览文件 @
e6aacd1e
...
...
@@ -1476,6 +1476,14 @@ class Fleet(object):
context
[
"graph_optimize_ops"
]
=
optimize_ops
context
[
"graph_optimize_grads"
]
=
params_grads
program
=
paddle
.
static
.
default_main_program
()
opt_info
=
{}
opt_info
[
"mpi_size"
]
=
self
.
worker_num
()
opt_info
[
"mpi_rank"
]
=
self
.
worker_index
()
for
k
,
v
in
self
.
_user_defined_strategy
.
trainer_desc_configs
.
items
():
opt_info
[
k
]
=
v
program
.
_fleet_opt
=
opt_info
if
self
.
_runtime_handle
is
None
:
self
.
_runtime_handle
=
RuntimeFactory
().
_create_runtime
(
context
)
...
...
python/paddle/fluid/tests/unittests/test_dist_fleet_trainer_desc_config.py
0 → 100644
浏览文件 @
e6aacd1e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
time
import
unittest
import
paddle
import
paddle.distributed.fleet.base.role_maker
as
role_maker
import
paddle.fluid.transpiler.details.program_utils
as
pu
paddle
.
enable_static
()
class
TestDistStrategyTrainerDescConfig
(
unittest
.
TestCase
):
def
setUp
(
self
):
os
.
environ
[
"PADDLE_PSERVER_NUMS"
]
=
"2"
os
.
environ
[
"PADDLE_TRAINERS_NUM"
]
=
"2"
os
.
environ
[
"POD_IP"
]
=
"127.0.0.1"
os
.
environ
[
"PADDLE_PORT"
]
=
"36001"
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"0"
os
.
environ
[
"PADDLE_PSERVERS_IP_PORT_LIST"
]
=
\
"127.0.0.1:36001,127.0.0.2:36001"
def
test_trainer_desc_config
(
self
):
os
.
environ
[
"TRAINING_ROLE"
]
=
"TRAINER"
import
paddle.distributed.fleet
as
fleet
fleet
.
init
(
role_maker
.
PaddleCloudRoleMaker
())
x
=
paddle
.
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
1
],
dtype
=
'float32'
)
y
=
paddle
.
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
cost
=
paddle
.
fluid
.
layers
.
square_error_cost
(
input
=
x
,
label
=
y
)
avg_cost
=
paddle
.
fluid
.
layers
.
mean
(
cost
)
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
config
=
{
"dump_fields_path"
:
"dump_data"
,
"dump_fields"
:
[
"xxx"
,
"yyy"
],
"dump_param"
:
[]
}
strategy
.
trainer_desc_configs
=
config
optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.01
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
strategy
)
optimizer
.
minimize
(
avg_cost
)
program
=
paddle
.
static
.
default_main_program
()
self
.
assertEqual
(
program
.
_fleet_opt
[
"dump_fields_path"
],
"dump_data"
)
self
.
assertEqual
(
len
(
program
.
_fleet_opt
[
"dump_fields"
]),
2
)
self
.
assertEqual
(
len
(
program
.
_fleet_opt
[
"dump_param"
]),
0
)
self
.
assertEqual
(
program
.
_fleet_opt
[
"mpi_size"
],
int
(
os
.
environ
[
"PADDLE_TRAINERS_NUM"
]))
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
浏览文件 @
e6aacd1e
...
...
@@ -255,6 +255,19 @@ class TestStrategyConfig(unittest.TestCase):
strategy
.
a_sync_configs
=
configs
self
.
assertEqual
(
strategy
.
a_sync_configs
[
"k_steps"
],
1000
)
def
test_trainer_desc_configs
(
self
):
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
configs
=
{
"dump_fields_path"
:
"dump_data"
,
"dump_fields"
:
[
"xxx"
,
"yyy"
],
"dump_param"
:
[]
}
strategy
.
trainer_desc_configs
=
configs
self
.
assertEqual
(
strategy
.
trainer_desc_configs
[
"dump_fields_path"
],
"dump_data"
)
self
.
assertEqual
(
len
(
strategy
.
trainer_desc_configs
[
"dump_fields"
]),
2
)
self
.
assertEqual
(
len
(
strategy
.
trainer_desc_configs
[
"dump_param"
]),
0
)
def
test_elastic
(
self
):
strategy
=
paddle
.
distributed
.
fleet
.
DistributedStrategy
()
strategy
.
elastic
=
True
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录