Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5199c744
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5199c744
编写于
9月 08, 2021
作者:
L
lilong12
提交者:
GitHub
9月 08, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support weight sharing for pipeline (#35351)
* support weight sharing
上级
18a963a5
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
19 addition
and
14 deletion
+19
-14
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+19
-14
未找到文件。
python/paddle/fluid/optimizer.py
浏览文件 @
5199c744
...
@@ -22,7 +22,7 @@ from collections import defaultdict
...
@@ -22,7 +22,7 @@ from collections import defaultdict
import
paddle
import
paddle
from
paddle.fluid.distribute_lookup_table
import
find_distributed_lookup_table
from
paddle.fluid.distribute_lookup_table
import
find_distributed_lookup_table
from
paddle.fluid.framework
import
Program
,
Variable
,
name_scope
,
default_main_program
,
default_startup_program
,
device_guard
from
paddle.fluid.framework
import
Program
,
Variable
,
Parameter
,
name_scope
,
default_main_program
,
default_startup_program
,
device_guard
from
.
import
framework
from
.
import
framework
from
.
import
layers
from
.
import
layers
...
@@ -4234,14 +4234,14 @@ class PipelineOptimizer(object):
...
@@ -4234,14 +4234,14 @@ class PipelineOptimizer(object):
self
.
_device
=
"gpu"
self
.
_device
=
"gpu"
if
framework
.
in_dygraph_mode
():
if
framework
.
in_dygraph_mode
():
raise
Exception
(
"In dygraph, don't support PipelineOptimizer."
)
raise
Exception
(
"In dygraph, don't support PipelineOptimizer."
)
if
not
isinstance
(
optimizer
,
Optimizer
)
and
not
isinstance
(
valid_optimizers
=
(
Optimizer
,
paddle
.
optimizer
.
Optimizer
,
optimizer
,
paddle
.
optimizer
.
Optimizer
)
and
not
isinstance
(
paddle
.
fluid
.
contrib
.
mixed_precision
.
decorator
.
optimizer
,
paddle
.
fluid
.
contrib
.
mixed_precision
.
decorator
.
OptimizerWithMixedPrecision
)
OptimizerWithMixedPrecision
):
if
not
isinstance
(
optimizer
,
valid_optimizers
):
raise
ValueError
(
"The 'optimizer' parameter for "
raise
ValueError
(
"The 'optimizer' parameter for "
"PipelineOptimizer must be an instance of "
"PipelineOptimizer must be an instance of "
"
Optimizer
, but the given type is {}."
.
format
(
"
{}
, but the given type is {}."
.
format
(
type
(
optimizer
)))
valid_optimizers
,
type
(
optimizer
)))
self
.
_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
# Get the original optimizer defined by users, such as SGD
# Get the original optimizer defined by users, such as SGD
...
@@ -4774,14 +4774,13 @@ class PipelineOptimizer(object):
...
@@ -4774,14 +4774,13 @@ class PipelineOptimizer(object):
# skip data var
# skip data var
if
var
.
is_data
:
continue
if
var
.
is_data
:
continue
prev_device
=
None
prev_device
=
None
generate_ops
=
self
.
output_var_to_op
.
get
(
var_name
)
if
generate_ops
is
None
:
prev_op
=
self
.
_find_prev_op
(
index
,
var_name
)
if
prev_op
is
None
:
if
var_name
not
in
self
.
_param_device_map
:
if
var_name
not
in
self
.
_param_device_map
:
continue
continue
prev_device
=
self
.
_param_device_map
[
var_name
]
prev_device
=
self
.
_param_device_map
[
var_name
]
prev_op
=
self
.
_find_prev_op
(
index
,
var_name
)
if
not
prev_device
:
if
not
prev_device
:
prev_device
=
prev_op
.
attr
(
self
.
_op_device_key
)
\
prev_device
=
prev_op
.
attr
(
self
.
_op_device_key
)
\
if
prev_op
else
None
if
prev_op
else
None
...
@@ -4928,9 +4927,14 @@ class PipelineOptimizer(object):
...
@@ -4928,9 +4927,14 @@ class PipelineOptimizer(object):
self
.
_op_role_key
:
op_role
,
self
.
_op_role_key
:
op_role
,
})
})
extra_index_info
[
'index'
]
+=
1
extra_index_info
[
'index'
]
+=
1
prefix_name
=
var
.
name
.
split
(
'@'
)[
0
]
prefix_var
=
block
.
var
(
prefix_name
)
is_param
=
True
if
isinstance
(
prefix_var
,
Parameter
)
else
False
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
index
=
index
+
extra_index_info
[
'index'
],
index
=
index
+
extra_index_info
[
'index'
],
type
=
'send_v2'
if
not
use_mp
else
'partial_send'
,
type
=
'send_v2'
if
not
use_mp
or
is_param
else
'partial_send'
,
inputs
=
{
'X'
:
var
},
inputs
=
{
'X'
:
var
},
attrs
=
{
attrs
=
{
self
.
_op_device_key
:
prev_dev
,
self
.
_op_device_key
:
prev_dev
,
...
@@ -4966,7 +4970,8 @@ class PipelineOptimizer(object):
...
@@ -4966,7 +4970,8 @@ class PipelineOptimizer(object):
extra_index_info
[
'index'
]
+=
1
extra_index_info
[
'index'
]
+=
1
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
index
=
index
+
extra_index_info
[
'index'
],
index
=
index
+
extra_index_info
[
'index'
],
type
=
'recv_v2'
if
not
use_mp
else
'partial_recv'
,
type
=
'recv_v2'
if
not
use_mp
or
is_param
else
'partial_recv'
,
outputs
=
{
'Out'
:
[
var
]},
outputs
=
{
'Out'
:
[
var
]},
attrs
=
{
attrs
=
{
'out_shape'
:
var_shape
,
'out_shape'
:
var_shape
,
...
@@ -4981,7 +4986,7 @@ class PipelineOptimizer(object):
...
@@ -4981,7 +4986,7 @@ class PipelineOptimizer(object):
'id'
:
self
.
mp_rank
,
'id'
:
self
.
mp_rank
,
})
})
extra_index_info
[
'index'
]
+=
1
extra_index_info
[
'index'
]
+=
1
if
use_mp
:
if
use_mp
and
not
is_param
:
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
index
=
index
+
extra_index_info
[
'index'
],
index
=
index
+
extra_index_info
[
'index'
],
type
=
'partial_allgather'
,
type
=
'partial_allgather'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录