Commit 6b4ca0d7 (unverified)
Authored Sep 04, 2020 by danleifeng; committed via GitHub on Sep 04, 2020.
【paddle.fleet】distributed_optimizer supports dygraph (#26541)
paddle.distributed.fleet supports dynamic graph execution.
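For orientation, here is the end-to-end dygraph workflow this commit enables, condensed from the docstring example added to fleet_base.py below (the LinearNet class from that example is replaced by a single nn.Linear for brevity):

import paddle
import paddle.nn as nn
from paddle.distributed import fleet

def train():
    paddle.disable_static()                      # 1. enable dynamic mode
    fleet.init(is_collective=True)               # 2. initialize fleet environment

    layer = nn.Linear(10, 1)                     # 3. create layer & optimizer
    loss_fn = nn.MSELoss()
    adam = paddle.optimizer.Adam(
        learning_rate=0.001, parameters=layer.parameters())

    adam = fleet.distributed_optimizer(adam)     # 4. get data-parallel wrappers
    dp_layer = fleet.distributed_model(layer)

    inputs = paddle.randn([10, 10], 'float32')   # 5. run one training step
    loss = loss_fn(dp_layer(inputs), paddle.randn([10, 1], 'float32'))

    loss = dp_layer.scale_loss(loss)
    loss.backward()
    dp_layer.apply_collective_grads()
    adam.step()
    adam.clear_grad()

if __name__ == '__main__':
    paddle.distributed.spawn(train)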
Parent: c8cc0945

Showing 6 changed files with 443 additions and 3 deletions (+443 -3).
python/paddle/distributed/fleet/__init__.py                           +7   -0
python/paddle/distributed/fleet/base/fleet_base.py                    +345 -0
python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py         +2   -2
python/paddle/fluid/tests/unittests/test_dist_base.py                 +47  -1
python/paddle/fluid/tests/unittests/test_fleet_base.py                +26  -0
python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py    +16  -0
python/paddle/distributed/fleet/__init__.py
@@ -50,3 +50,10 @@ distributed_optimizer = fleet.distributed_optimizer
save_inference_model = fleet.save_inference_model
save_persistables = fleet.save_persistables
minimize = fleet.minimize
distributed_model = fleet.distributed_model
step = fleet.step
clear_grad = fleet.clear_grad
set_lr = fleet.set_lr
get_lr = fleet.get_lr
state_dict = fleet.state_dict
set_state_dict = fleet.set_state_dict
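These module-level aliases expose the new dygraph methods both on the object returned by fleet.distributed_optimizer (which, per the fleet_base.py change below, is the Fleet instance itself in dygraph mode) and on the fleet module directly. A minimal sketch, mirroring the TestFleetDygraph unit test added below (which skips fleet.init because that test cannot launch a distributed task):

import paddle
from paddle.distributed import fleet

paddle.disable_static()

layer = paddle.nn.Linear(13, 5)
adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

adam = fleet.distributed_optimizer(adam)   # in dygraph mode this returns the Fleet instance
adam.set_lr(0.001)                         # forwarded to the wrapped paddle.optimizer.Adam
print(adam.get_lr())                       # 0.001

fleet.set_lr(0.001)                        # the aliases reach the same bound methods
print(fleet.get_lr())                      # 0.001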
python/paddle/distributed/fleet/base/fleet_base.py
@@ -15,6 +15,7 @@
from __future__ import print_function
import warnings
import paddle
from paddle.fluid.framework import dygraph_only
from paddle.fluid import compiler
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
from .strategy_compiler import StrategyCompiler
@@ -23,6 +24,7 @@ from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
from .util_factory import UtilFactory
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper


def _inited_runtime_handler_(func):
@@ -178,6 +180,12 @@ class Fleet(object):
                "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
                format(type(role_maker)))
        self.strategy_compiler = StrategyCompiler()
        if paddle.fluid.framework.in_dygraph_mode():
            if parallel_helper._is_parallel_ctx_initialized():
                warnings.warn(
                    "The dygraph parallel environment has been initialized.")
            else:
                paddle.distributed.init_parallel_env()
        return None

    def is_first_worker(self):
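In other words, once this lands a dygraph trainer only needs fleet.init(is_collective=True); the parallel context is created for it, or a warning is emitted if something already initialized it. A minimal sketch of the intended call order:

import paddle
from paddle.distributed import fleet

paddle.disable_static()              # dygraph mode
fleet.init(is_collective=True)       # internally calls paddle.distributed.init_parallel_env()
                                     # unless the parallel context already exists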
@@ -587,12 +595,344 @@ class Fleet(object):
        """
        self.user_defined_optimizer = optimizer
        if paddle.fluid.framework.in_dygraph_mode():
            return self

        if strategy == None:
            strategy = DistributedStrategy()
        self.user_defined_strategy = strategy
        self.valid_strategy = None
        return self

    @dygraph_only
    def distributed_model(self, model):
        """
        Return dygraph distributed data parallel model (Layer)
        Only work in dygraph mode

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                def train():
                    # 1. enable dynamic mode
                    paddle.disable_static()

                    # 2. initialize fleet environment
                    fleet.init(is_collective=True)

                    # 3. create layer & optimizer
                    layer = LinearNet()
                    loss_fn = nn.MSELoss()
                    adam = paddle.optimizer.Adam(
                        learning_rate=0.001, parameters=layer.parameters())

                    # 4. get data_parallel model using fleet
                    adam = fleet.distributed_optimizer(adam)
                    dp_layer = fleet.distributed_model(layer)

                    # 5. run layer
                    inputs = paddle.randn([10, 10], 'float32')
                    outputs = dp_layer(inputs)
                    labels = paddle.randn([10, 1], 'float32')
                    loss = loss_fn(outputs, labels)

                    print("loss:", loss.numpy())

                    loss = dp_layer.scale_loss(loss)
                    loss.backward()
                    dp_layer.apply_collective_grads()

                    adam.step()
                    adam.clear_grad()

                if __name__ == '__main__':
                    paddle.distributed.spawn(train)
        """
        assert model is not None
        self.model = paddle.DataParallel(model)
        return self.model
    @dygraph_only
    def state_dict(self):
        """
        Get state dict information from optimizer.
        Only work in dygraph mode

        Returns:
            state_dict(dict) : dict contains all the Tensor used by optimizer

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.fluid.dygraph.to_variable(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.state_dict()

    @dygraph_only
    def set_state_dict(self, state_dict):
        """
        Load optimizer state dict.
        Only work in dygraph mode

        Args:
            state_dict(dict) : Dict contains all the Tensor needed by optimizer

        Returns: None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.fluid.dygraph.to_variable(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
                paddle.framework.save(state_dict, "paddle_dy")
                para_state_dict, opti_state_dict = paddle.framework.load("paddle_dy")
                adam.set_state_dict(opti_state_dict)
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_state_dict(state_dict)
    @dygraph_only
    def set_lr(self, value):
        """
        Set the value of the learning rate manually in the optimizer.
        Only work in dygraph mode

        Args:
            value (float|Tensor): the value of learning rate

        Returns: None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.fluid.dygraph.to_variable(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_lr(value)

    @dygraph_only
    def get_lr(self):
        """
        Get current step learning rate.
        Only work in dygraph mode

        Returns:
            float: The learning rate of the current step.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.fluid.dygraph.to_variable(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr = adam.get_lr()
                print(lr)  # 0.01
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.get_lr()
    @dygraph_only
    def step(self):
        """
        Execute the optimizer once.
        Only work in dygraph mode

        Returns: None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                def train():
                    # 1. enable dynamic mode
                    paddle.disable_static()

                    # 2. initialize fleet environment
                    fleet.init(is_collective=True)

                    # 3. create layer & optimizer
                    layer = LinearNet()
                    loss_fn = nn.MSELoss()
                    adam = paddle.optimizer.Adam(
                        learning_rate=0.001, parameters=layer.parameters())

                    # 4. get data_parallel model using fleet
                    adam = fleet.distributed_optimizer(adam)
                    dp_layer = fleet.distributed_model(layer)

                    # 5. run layer
                    inputs = paddle.randn([10, 10], 'float32')
                    outputs = dp_layer(inputs)
                    labels = paddle.randn([10, 1], 'float32')
                    loss = loss_fn(outputs, labels)

                    print("loss:", loss.numpy())

                    loss = dp_layer.scale_loss(loss)
                    loss.backward()
                    dp_layer.apply_collective_grads()

                    adam.step()
                    adam.clear_grad()

                if __name__ == '__main__':
                    paddle.distributed.spawn(train)
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.step()
    @dygraph_only
    def clear_grad(self):
        """
        Clear the gradients of all optimized parameters.
        Only work in dygraph mode

        Returns: None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                def train():
                    # 1. enable dynamic mode
                    paddle.disable_static()

                    # 2. initialize fleet environment
                    fleet.init(is_collective=True)

                    # 3. create layer & optimizer
                    layer = LinearNet()
                    loss_fn = nn.MSELoss()
                    adam = paddle.optimizer.Adam(
                        learning_rate=0.001, parameters=layer.parameters())

                    # 4. get data_parallel model using fleet
                    adam = fleet.distributed_optimizer(adam)
                    dp_layer = fleet.distributed_model(layer)

                    # 5. run layer
                    inputs = paddle.randn([10, 10], 'float32')
                    outputs = dp_layer(inputs)
                    labels = paddle.randn([10, 1], 'float32')
                    loss = loss_fn(outputs, labels)

                    print("loss:", loss.numpy())

                    loss = dp_layer.scale_loss(loss)
                    loss.backward()
                    dp_layer.apply_collective_grads()

                    adam.step()
                    adam.clear_grad()

                if __name__ == '__main__':
                    paddle.distributed.spawn(train)
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.clear_grad()
    def minimize(self, loss, startup_program=None, ...

@@ -642,6 +982,11 @@ class Fleet(object):
        # for more examples, please reference https://github.com/PaddlePaddle/FleetX

        """
        if paddle.fluid.framework.in_dygraph_mode():
            # imitate target optimizer retrieval
            target_opt = self.user_defined_optimizer
            return target_opt.minimize(loss)

        context = {}
        # cache original feed forward program
        self.origin_main_program = loss.block.program
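So in dygraph mode minimize() is a thin passthrough to the wrapped optimizer. A sketch of the resulting usage, assuming the layer/optimizer setup from the distributed_model example above and that paddle.optimizer.Adam.minimize applies the gradients accumulated by the preceding backward pass:

outputs = dp_layer(inputs)
loss = loss_fn(outputs, labels)

loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()

adam.minimize(loss)    # dygraph branch: user_defined_optimizer.minimize(loss)
adam.clear_grad()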
python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py
@@ -114,8 +114,8 @@ class TestMnist(TestParallelDyGraphRunnerBase):
         model = MNIST()
         train_reader = paddle.batch(
             paddle.dataset.mnist.train(), batch_size=2, drop_last=True)
-        opt = fluid.optimizer.Adam(
-            learning_rate=1e-3, parameter_list=model.parameters())
+        opt = paddle.optimizer.Adam(
+            learning_rate=1e-3, parameters=model.parameters())
         return model, train_reader, opt

     def run_one_loop(self, model, opt, data):
python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -488,6 +488,50 @@ class TestParallelDyGraphRunnerBase(object):
                model.clear_gradients()
        return out_losses

    def run_gpu_fleet_api_trainer(self, args):
        import paddle.distributed.fleet as fleet
        import paddle.distributed.fleet.base.role_maker as role_maker
        # 1. enable dygraph
        paddle.disable_static()

        # 2. init seed
        seed = 90
        paddle.static.default_startup_program().random_seed = seed
        paddle.static.default_main_program().random_seed = seed
        np.random.seed(seed)
        random.seed = seed
        # get trainer id
        args.trainer_id = paddle.distributed.get_rank()

        # 3. init parallel env
        if args.update_method == "nccl2":
            fleet.init(is_collective=True)

        # 4. train model
        model, train_reader, opt = self.get_model()
        if args.update_method == "nccl2":
            opt = fleet.distributed_optimizer(opt)
            model = fleet.distributed_model(model)

        out_losses = []
        for step_id, data in enumerate(train_reader()):
            data = self._get_data(data, args)
            if step_id == RUN_STEP:
                break
            loss = self.run_one_loop(model, opt, data)
            out_losses.append(loss.numpy())

            if args.update_method == "nccl2":
                loss = model.scale_loss(loss)

            loss.backward()
            if args.update_method == "nccl2":
                model.apply_collective_grads()

            opt.step()
            opt.clear_grad()
        print_to_out(out_losses)


def runtime_main(test_class):
    parser = argparse.ArgumentParser(description='Run dist test.')
@@ -687,7 +731,8 @@ class TestDistBase(unittest.TestCase):
            envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
            cmd += " -m coverage run --branch -p"

-        cmd += " %s --role trainer --lr %f" % (model, self._lr)
+        cmd += " %s --role trainer --update_method local --lr %f" % (model,
+                                                                     self._lr)

        if batch_size != DEFAULT_BATCH_SIZE:
            cmd += " --batch_size %d" % batch_size
@@ -850,6 +895,7 @@
        if self.__use_cuda:
            tr_cmd += " --use_cuda"
            env.update({
                "FLAGS_selected_gpus": "{}".format(0),
                "CUDA_VISIBLE_DEVICES": "{}".format(trainer_id % 2),
                "PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
                "PADDLE_TRAINER_ID": "{}".format(trainer_id),
python/paddle/fluid/tests/unittests/test_fleet_base.py
@@ -126,6 +126,32 @@ class TestFleetBase(unittest.TestCase):
        self.assertRaises(Exception, fleet.init_worker)


class TestFleetDygraph(unittest.TestCase):
    def setUp(self):
        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213,127.0.0.1:36214"
        os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
        os.environ["PADDLE_TRAINERS_NUM"] = "2"
        os.environ["PADDLE_TRAINER_ID"] = "0"

    def test_dygraph_method(self):
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = fluid.dygraph.to_variable(value)
        layer = paddle.nn.Linear(13, 5)
        adam = paddle.optimizer.Adam(
            learning_rate=0.01, parameters=layer.parameters())
        # remove init cause this UT cannot launch distributed task
        adam = fleet.distributed_optimizer(adam)
        dp_layer = fleet.distributed_model(layer)
        lr = 0.001
        adam.set_lr(lr)
        cur_lr = adam.get_lr()
        assert (lr == cur_lr)
        state_dict = adam.state_dict()
        adam.set_state_dict(state_dict)


class TestFleetBaseSingleRunCollective(unittest.TestCase):
    def setUp(self):
        os.environ.pop("PADDLE_TRAINER_ENDPOINTS")
python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
@@ -47,5 +47,21 @@ class TestParallelDygraphMnistSpawn(TestDistSpawnRunner):
            self.check_dist_result_with_spawn(test_class=TestMnist, delta=1e-5)


class TestFleetDygraphMnist(TestDistBase):
    def _setup_config(self):
        self._sync_mode = False
        self._nccl2_mode = True
        self._dygraph = True
        self._gpu_fleet_api = True

    def test_mnist(self):
        if fluid.core.is_compiled_with_cuda():
            self.check_with_place(
                "parallel_dygraph_mnist.py",
                delta=1e-5,
                check_error_log=True,
                log_name=flag_name)


if __name__ == "__main__":
    unittest.main()