BaiXuePrincess / Paddle
Unverified commit 8d05c00c
Authored Sep 17, 2020 by danleifeng; committed by GitHub on Sep 17, 2020
fix paddle.fleet en-doc for apis in dynamic mode (#27354)
* fix fleet dynamic-mode en-doc;test=develop
Parent 746a8ded
Showing 1 changed file with 121 additions and 114 deletions.

python/paddle/distributed/fleet/base/fleet_base.py  +121  -114
python/paddle/distributed/fleet/base/fleet_base.py @ 8d05c00c
...
...
@@ -608,25 +608,31 @@ class Fleet(object):
    @dygraph_only
    def distributed_model(self, model):
        """
        Return distributed data parallel model (only works in dygraph mode).

        Args:
            model (Layer): the user-defined model which inherits Layer.

        Returns:
            distributed data parallel model which inherits Layer.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                def train():
                    # 1. enable dynamic mode
                    paddle.disable_static()
...
@@ -658,8 +664,7 @@ class Fleet(object):
                    adam.step()
                    adam.clear_grad()

                if __name__ == '__main__':
                    paddle.distributed.spawn(train)
        """
        assert model is not None
        self.model = paddle.DataParallel(model)
...
...
@@ -669,29 +674,30 @@ class Fleet(object):
    def state_dict(self):
        """
        Get state dict information from the optimizer (only works in dygraph mode).

        Returns:
            state_dict(dict): dict containing all the Tensors used by the optimizer

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.fluid.dygraph.to_variable(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.state_dict()
...
...
@@ -700,34 +706,36 @@ class Fleet(object):
    def set_state_dict(self, state_dict):
        """
        Load the optimizer state dict (only works in dygraph mode).

        Args:
            state_dict(dict): dict containing all the Tensors needed by the optimizer

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.fluid.dygraph.to_variable(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
                paddle.framework.save(state_dict, "paddle_dy")
                para_state_dict, opti_state_dict = paddle.framework.load("paddle_dy")
                adam.set_state_dict(opti_state_dict)
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_state_dict(state_dict)
...
...
@@ -736,42 +744,44 @@ class Fleet(object):
    def set_lr(self, value):
        """
        Set the value of the learning rate manually in the optimizer (only works in dygraph mode).

        Args:
            value (float|Tensor): the value of the learning rate

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.fluid.dygraph.to_variable(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Prints:
                # current lr is 0.2
                # current lr is 0.3
                # current lr is 0.4
                # current lr is 0.5
                # current lr is 0.6
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_lr(value)
...
...
@@ -780,31 +790,32 @@ class Fleet(object):
    def get_lr(self):
        """
        Get the learning rate of the current step (only works in dygraph mode).

        Returns:
            float: The learning rate of the current step.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.fluid.dygraph.to_variable(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr = adam.get_lr()
                print(lr)  # 0.01
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.get_lr()
...
...
@@ -813,27 +824,27 @@ class Fleet(object):
    def step(self):
        """
        Execute the optimizer once (only works in dygraph mode).

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                def train():
                    # 1. enable dynamic mode
                    paddle.disable_static()
...
...
@@ -865,8 +876,6 @@ class Fleet(object):
                    adam.step()
                    adam.clear_grad()

                if __name__ == '__main__':
                    paddle.distributed.spawn(train)
        """
        # imitate target optimizer retrieval
...
...
@@ -875,28 +884,28 @@ class Fleet(object):
    @dygraph_only
    def clear_grad(self):
        """
        Clear the gradients of all optimized parameters for the model (only works in dygraph mode).

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                def train():
                    # 1. enable dynamic mode
                    paddle.disable_static()
...
...
@@ -928,8 +937,6 @@ class Fleet(object):
                    adam.step()
                    adam.clear_grad()

                if __name__ == '__main__':
                    paddle.distributed.spawn(train)
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.clear_grad()
...
...
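Every `train()` example in the docstrings above is truncated by this diff, with the middle steps elided as context. For reference, a minimal end-to-end sketch assembled from the fragments that are visible here (the LinearNet definition, fleet.init(is_collective=True), fleet.distributed_optimizer, fleet.distributed_model, adam.step / adam.clear_grad, and paddle.distributed.spawn) might look like the following; the loss function, tensor shapes, and learning rate are illustrative assumptions, not taken from the diff.

    # Minimal sketch of the dygraph-mode fleet training loop these
    # docstrings excerpt. Loss, shapes, and learning rate are assumed.
    import paddle
    import paddle.nn as nn
    from paddle.distributed import fleet

    class LinearNet(nn.Layer):
        def __init__(self):
            super(LinearNet, self).__init__()
            self._linear1 = nn.Linear(10, 10)
            self._linear2 = nn.Linear(10, 1)

        def forward(self, x):
            return self._linear2(self._linear1(x))

    def train():
        # 1. enable dynamic mode
        paddle.disable_static()
        # 2. initialize the collective fleet environment
        fleet.init(is_collective=True)
        # 3. build the model and a plain optimizer
        layer = LinearNet()
        loss_fn = nn.MSELoss()  # assumed loss for illustration
        adam = paddle.optimizer.Adam(
            learning_rate=0.001, parameters=layer.parameters())
        # 4. wrap both with fleet for data-parallel training
        adam = fleet.distributed_optimizer(adam)
        dp_layer = fleet.distributed_model(layer)
        # 5. run one training step on random data (assumed shapes)
        inputs = paddle.randn([10, 10], 'float32')
        labels = paddle.randn([10, 1], 'float32')
        outputs = dp_layer(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        adam.step()
        adam.clear_grad()

    if __name__ == '__main__':
        paddle.distributed.spawn(train)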