机器未来 / Paddle (forked from PaddlePaddle/Paddle)
Commit 3a5f1f22 (unverified)

[hybird optim] reduce rend/recv times for recompute, test=develop (#34248)

Authored Jul 20, 2021 by Yuang Liu; committed by GitHub on Jul 20, 2021.
Parent: 7f2b5be3
Changes: 3 changed files with 111 additions and 11 deletions (+111 −11). Whitespace-only changes are hidden in the diffs below.

  python/paddle/fluid/optimizer.py  (+33 −11)
  python/paddle/fluid/tests/unittests/CMakeLists.txt  (+2 −0)
  python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer_with_recompute.py  (+76 −0)
python/paddle/fluid/optimizer.py (view file @ 3a5f1f22)

@@ -4867,6 +4867,39 @@ class PipelineOptimizer(object):
                     })
                 extra_index_info['index'] += 1
             elif self.schedule_mode == '1F1B':  # 1F1B
+                var_shape = list(var.shape)
+                var_shape[0] = self.micro_batch_size if var_shape[
+                    0] < 0 else var_shape[0]
+
+                numel = np.prod(var.shape)
+                assert numel % self.mp_degree == 0, \
+                    "The numel={} must be divisible by mp_degree={}".format(
+                        numel, self.mp_degree)
+
+                if 'subprog' in var.name:
+                    # For recompute, if the checkpoints var is layer_norm_6.tmp_2
+                    # this var will be sent twice, layer_norm_6.tmp_2 for forward pass,
+                    # layer_norm_6.tmp_2.subprog_* for recompute pass.
+                    # We can store the first sent var and copy the value to the
+                    # second one to reduce one send/recv op.
+                    # The origin_ckpt_name is layer_norm_6.tmp_2, which will be used
+                    # to find the stored var for the forward pass.
+                    origin_name = var.name.split('subprog')[0][0:-1]
+                    associate_var = block.var(origin_name)
+                    block._insert_op_without_sync(
+                        index=index + extra_index_info['index'],
+                        type='assign',
+                        inputs={'X': [associate_var]},
+                        outputs={'Out': [var]},
+                        attrs={
+                            'out_shape': var_shape,
+                            'dtype': var.dtype,
+                            self._op_device_key: cur_dev,
+                            self._op_role_key: op_role,
+                            'use_calc_stream': True,
+                        })
+                    extra_index_info['index'] += 1
+                    return
                 block._insert_op_without_sync(
                     index=index + extra_index_info['index'],
                     type='c_sync_calc_stream',
@@ -4894,7 +4927,6 @@ class PipelineOptimizer(object):
                     })
                 extra_index_info['index'] += 1
                 insert_index = None
                 if int(op_role) == int(self._op_role.Backward):
                     insert_index = extra_index_info[
                         'first_optimize_index']
@@ -4902,7 +4934,6 @@ class PipelineOptimizer(object):
                 else:
                     insert_index = index
                 new_op_role = self._op_role.Backward
                 sync_comm_op = block._insert_op_without_sync(
                     index=insert_index + extra_index_info['index'],
                     type='c_sync_comm_stream',
@@ -4913,18 +4944,9 @@ class PipelineOptimizer(object):
                         self._op_role_key: new_op_role,
                         'ring_id': ring_id,
                     })
                 if int(op_role) == int(self._op_role.Forward):
                     sync_comm_op._set_attr('pipeline_flag', '')
                 extra_index_info['index'] += 1
-
-                var_shape = list(var.shape)
-                var_shape[0] = self.micro_batch_size if var_shape[
-                    0] < 0 else var_shape[0]
-
-                numel = np.prod(var.shape)
-                assert numel % self.mp_degree == 0, \
-                    "The numel={} must be divisible by mp_degree={}".format(
-                        numel, self.mp_degree)
-
                 block._insert_op_without_sync(
                     index=index + extra_index_info['index'],
                     type='recv_v2'
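The heart of the change is the name mapping in the new 'subprog' branch: the recompute alias of a checkpoint is its forward-pass name plus a `.subprog_*` suffix, and stripping that suffix recovers the variable that was already sent. A minimal standalone sketch of that mapping, reusing the hypothetical `layer_norm_6.tmp_2` example from the diff's own comments:

    # Recover the forward-pass checkpoint name from its recompute alias,
    # exactly as the diff does: split on 'subprog', keep the prefix, and
    # drop its trailing '.'.
    def origin_ckpt_name(name):
        return name.split('subprog')[0][0:-1]

    assert origin_ckpt_name('layer_norm_6.tmp_2.subprog_0') == 'layer_norm_6.tmp_2'

With the original name in hand, the pass looks up the already-received variable via `block.var(origin_name)` and inserts an `assign` op in place of a second send/recv pair.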
python/paddle/fluid/tests/unittests/CMakeLists.txt (view file @ 3a5f1f22)

@@ -17,6 +17,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer)
 list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
+list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer_with_recompute)
 list(APPEND DIST_TEST_OPS test_fleet_raw_program_meta_optimizer)
 list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
 list(APPEND DIST_TEST_OPS test_gen_nccl_id_op)
@@ -56,6 +57,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_2)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_3)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
+list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer_with_recompute)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_raw_program_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init)
python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer_with_recompute.py (new file, 0 → 100644; view file @ 3a5f1f22)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os

paddle.enable_static()


class TestFleetMetaOptimizer(unittest.TestCase):
    def setUp(self):
        os.environ["PADDLE_TRAINER_ID"] = "1"
        os.environ[
            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"

    def test_pipeline_optimizer(self):
        import paddle.distributed.fleet as fleet
        import paddle.distributed.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        with paddle.fluid.device_guard("gpu:0"):
            input_x = paddle.fluid.layers.data(
                name="x", shape=[32], dtype='float32')
            input_y = paddle.fluid.layers.data(
                name="y", shape=[1], dtype='int64')
            fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
            fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
            fc_3 = paddle.fluid.layers.fc(input=fc_2, size=64, act='tanh')
            fc_4 = paddle.fluid.layers.fc(input=fc_3, size=64, act='tanh')
            fc_5 = paddle.fluid.layers.fc(input=fc_4, size=64, act='tanh')
            fc_6 = paddle.fluid.layers.fc(input=fc_5, size=64, act='tanh')

        with paddle.fluid.device_guard("gpu:1"):
            fc_7 = paddle.fluid.layers.fc(input=fc_6, size=64, act='tanh')
            prediction = paddle.fluid.layers.fc(
                input=[fc_7], size=2, act='softmax')
            cost = paddle.fluid.layers.cross_entropy(
                input=prediction, label=input_y)
            avg_cost = paddle.fluid.layers.mean(x=cost)

        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.pipeline = True
        strategy.pipeline_configs = {
            'micro_batch_size': 1,
            'accumulate_steps': 2,
            'schedule_mode': '1F1B'
        }
        checkpoints = ['fc_5.tmp_0', 'fc_7.tmp_0']
        strategy.recompute = True
        strategy.recompute_configs = {
            "checkpoints": checkpoints,
            "enable_offload": False,
            "checkpoint_shape": []
        }

        optimizer = paddle.fluid.optimizer.Adam(0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)


if __name__ == "__main__":
    unittest.main()
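As a rough sanity check of what the commit saves (my summary of the diff's own comments, not something stated on this page): with pipeline plus recompute, a checkpoint variable that crosses a stage boundary used to be transferred twice per micro-batch, once under its forward name and once under its `.subprog_*` alias; after this change the alias is filled by a local assign instead.

    # Toy accounting of the saving described in the diff comments (a sketch,
    # not a Paddle API; real counts depend on the model and stage layout):
    sendrecv_per_ckpt_before = 2  # forward name + '.subprog_*' alias
    sendrecv_per_ckpt_after = 1   # alias now filled by a local 'assign'
    assert sendrecv_per_ckpt_before - sendrecv_per_ckpt_after == 1
    # i.e. one send/recv pair saved per boundary-crossing checkpoint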