Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7eae6570
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7eae6570
编写于
4月 21, 2022
作者:
S
ShenLiang
提交者:
GitHub
4月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix pipeline in new dygraph (#41937) (#42053)
* fix utest * fix time
上级
50fd2450
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
240 addition
and
54 deletion
+240
-54
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
...tributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+41
-9
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
...ddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+30
-27
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
...ributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+4
-3
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
.../paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+140
-4
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-1
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py
...on/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py
+5
-4
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py
...n/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py
+5
-5
python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py
...ests/unittests/test_parallel_dygraph_pipeline_parallel.py
+13
-0
python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py
.../tests/unittests/test_parallel_dygraph_tensor_parallel.py
+1
-1
未找到文件。
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
浏览文件 @
7eae6570
...
...
@@ -12,6 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# The file has been adapted from the file:
# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/pipe/module.py
# Git commit hash: fafc827d643b3eed611e282d909025f16be36601
# We retain the following license from the original files:
# MIT License
# Copyright (c) Microsoft Corporation.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
import
math
import
re
import
glob
...
...
@@ -24,6 +50,7 @@ import paddle
from
paddle.fluid.dygraph.layers
import
Layer
from
...utils.log_util
import
logger
,
layer_to_str
from
..pp_utils.utils
import
_hp_recompute
,
_initialize_recompute_setting
from
paddle.fluid.framework
import
in_dygraph_mode
__all__
=
[]
...
...
@@ -269,15 +296,20 @@ class PipelineLayer(Layer):
for
key
,
comm
in
self
.
shared_comm
.
items
():
param
=
getattr
(
self
.
shared_layers
[
key
],
comm
[
'weight_attr'
])
# need use trace_op to allreduce weight
with
paddle
.
framework
.
no_grad
():
paddle
.
fluid
.
framework
.
_dygraph_tracer
().
trace_op
(
type
=
"c_allreduce_sum"
,
inputs
=
{
'X'
:
param
.
_grad_ivar
()},
outputs
=
{
'Out'
:
param
.
_grad_ivar
()},
attrs
=
{
'ring_id'
:
comm
[
'group'
].
id
,
'use_calc_stream'
:
True
})
if
in_dygraph_mode
():
with
paddle
.
framework
.
no_grad
():
paddle
.
distributed
.
all_reduce
(
param
.
grad
,
group
=
comm
[
'group'
])
else
:
with
paddle
.
framework
.
no_grad
():
paddle
.
fluid
.
framework
.
_dygraph_tracer
().
trace_op
(
type
=
"c_allreduce_sum"
,
inputs
=
{
'X'
:
param
.
_grad_ivar
()},
outputs
=
{
'Out'
:
param
.
_grad_ivar
()},
attrs
=
{
'ring_id'
:
comm
[
'group'
].
id
,
'use_calc_stream'
:
True
})
def
_segment_network
(
self
,
seg_method
):
logger
.
info
(
"start segment network.."
)
...
...
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
浏览文件 @
7eae6570
...
...
@@ -23,6 +23,7 @@ from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from
..utils.log_util
import
logger
from
..meta_optimizers.dygraph_optimizer
import
HybridParallelOptimizer
,
HybridParallelGradScaler
from
.pp_utils
import
p2p_communication
as
p2p
import
paddle.fluid.core
as
core
__all__
=
[]
...
...
@@ -238,9 +239,9 @@ class PipelineParallel(MetaParallelBase):
assert
self
.
_layers
.
_loss_fn
is
not
None
,
"loss function should exist to compute loss"
labels
=
self
.
_load_micro_batch
(
self
.
micro_batch_id
)
output_tensor
=
self
.
_layers
.
_loss_fn
(
output_tensor
,
labels
)
assert
isinstance
(
output_tensor
,
paddle
.
Tensor
),
"Currently, loss_fn should obtain Paddle.Tensor dtype"
assert
isinstance
(
output_tensor
,
(
paddle
.
Tensor
,
core
.
eager
.
Tensor
)
)
,
"Currently, loss_fn should obtain Paddle.Tensor dtype"
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
if
self
.
accumulate_steps
>
1
:
...
...
@@ -254,31 +255,33 @@ class PipelineParallel(MetaParallelBase):
return
output_tensor
def
_backward_step
(
self
,
input_tensor
,
output_tensor
,
output_tensor_grad
):
if
self
.
is_last_stage
:
assert
output_tensor_grad
is
None
if
self
.
scaler
:
paddle
.
autograd
.
backward
(
self
.
scaler
.
scale
(
output_tensor
))
else
:
paddle
.
autograd
.
backward
(
output_tensor
)
else
:
if
isinstance
(
output_tensor
,
tuple
):
outputs
=
[
t
for
t
in
output_tensor
if
not
t
.
stop_gradient
]
assert
len
(
outputs
)
==
len
(
output_tensor_grad
)
paddle
.
autograd
.
backward
(
tensors
=
outputs
,
grad_tensors
=
[
t
for
t
in
output_tensor_grad
])
else
:
paddle
.
autograd
.
backward
(
tensors
=
[
output_tensor
],
grad_tensors
=
[
output_tensor_grad
])
input_tensor_grad
=
None
if
input_tensor
is
not
None
:
if
isinstance
(
input_tensor
,
tuple
):
input_tensor_grad
=
tuple
(
[
t
.
grad
for
t
in
input_tensor
if
not
t
.
stop_gradient
])
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
if
self
.
is_last_stage
:
assert
output_tensor_grad
is
None
if
self
.
scaler
:
paddle
.
autograd
.
backward
(
self
.
scaler
.
scale
(
output_tensor
))
else
:
paddle
.
autograd
.
backward
(
output_tensor
)
else
:
input_tensor_grad
=
input_tensor
.
grad
return
input_tensor_grad
if
isinstance
(
output_tensor
,
tuple
):
outputs
=
[
t
for
t
in
output_tensor
if
not
t
.
stop_gradient
]
assert
len
(
outputs
)
==
len
(
output_tensor_grad
)
paddle
.
autograd
.
backward
(
tensors
=
outputs
,
grad_tensors
=
[
t
for
t
in
output_tensor_grad
])
else
:
paddle
.
autograd
.
backward
(
tensors
=
[
output_tensor
],
grad_tensors
=
[
output_tensor_grad
])
input_tensor_grad
=
None
if
input_tensor
is
not
None
:
if
isinstance
(
input_tensor
,
tuple
):
input_tensor_grad
=
tuple
(
[
t
.
grad
for
t
in
input_tensor
if
not
t
.
stop_gradient
])
else
:
input_tensor_grad
=
input_tensor
.
grad
return
input_tensor_grad
def
_load_micro_batch
(
self
,
cache_id
):
inputs
=
self
.
data
...
...
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
浏览文件 @
7eae6570
...
...
@@ -17,6 +17,7 @@ from .utils import paddle_2_number, number_2_dtype
from
...utils.log_util
import
logger
import
numpy
as
np
from
paddle
import
_C_ops
import
paddle.fluid.core
as
core
_hcg
=
None
_use_cache
=
False
...
...
@@ -114,7 +115,7 @@ class SendRecvMeta:
paddle
.
distributed
.
send
(
stop_grad
,
dst
=
1
,
group
=
group
)
def
send_meta
(
self
,
tensor
,
group
):
if
isinstance
(
tensor
,
paddle
.
Tensor
):
if
isinstance
(
tensor
,
(
paddle
.
Tensor
,
core
.
eager
.
Tensor
)
):
tensor_type
=
paddle
.
to_tensor
([
0
])
# send tensor type
paddle
.
distributed
.
send
(
tensor_type
,
dst
=
1
,
group
=
group
)
...
...
@@ -129,11 +130,11 @@ class SendRecvMeta:
paddle
.
distributed
.
send
(
nums
,
dst
=
1
,
group
=
group
)
for
d
in
tensor
:
assert
isinstance
(
d
,
paddle
.
Tensor
)
assert
isinstance
(
d
,
(
paddle
.
Tensor
,
core
.
eager
.
Tensor
)
)
self
.
_send_dims_shape_dtype
(
d
,
group
=
group
)
def
set_send_message
(
self
,
tensor
):
if
isinstance
(
tensor
,
paddle
.
Tensor
):
if
isinstance
(
tensor
,
(
paddle
.
Tensor
,
core
.
eager
.
Tensor
)
):
self
.
send_shape_message
=
tensor
.
shape
self
.
send_dtype_message
=
paddle_2_number
(
tensor
.
dtype
)
elif
isinstance
(
tensor
,
tuple
):
...
...
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
浏览文件 @
7eae6570
...
...
@@ -17,10 +17,11 @@ import contextlib
import
paddle
from
paddle.fluid
import
core
from
paddle
import
_C_ops
from
paddle.autograd
import
PyLayer
from
paddle.autograd
import
PyLayer
,
EagerPyLayer
from
paddle.fluid
import
framework
from
...utils.recompute
import
check_recompute_necessary
,
detach_variable
from
..parallel_layers.random
import
get_rng_state_tracker
from
paddle.fluid.framework
import
in_dygraph_mode
__all__
=
[]
...
...
@@ -164,6 +165,138 @@ def _swith_rng_state_tracker(rng_state, tracker):
get_rng_state_tracker
().
set_states_tracker
(
orig_cuda_rng_tracker
)
class
_HPEagerRecomputeFunction
(
EagerPyLayer
):
"""
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type.
2. Offload support for activation
3. Support MP segmentation of activation to further reduce cuda memory
4. Adapt to the random state of MP
"""
@
staticmethod
def
forward
(
ctx
,
run_function
,
all_outputs
,
*
args
):
check_recompute_necessary
(
args
)
# store for recomputing
ctx
.
run_function
=
run_function
# store the rng states
ctx
.
fwd_cuda_rng_state
=
paddle
.
get_cuda_rng_state
()
ctx
.
fwd_cuda_rng_state_tracker
=
get_rng_state_tracker
(
).
get_states_tracker
()
# save input for backward
ctx
.
inputs
=
[]
ctx
.
tensor_indices
=
[]
ctx
.
tensor_shapes
=
[]
tensor_inputs
=
[]
cur_device
=
paddle
.
get_device
()
assert
'gpu:'
in
paddle
.
get_device
(
),
"Recompute with RNG is not support current device: {}."
.
format
(
cur_device
)
# TODO support AMP
tracer
=
framework
.
_dygraph_tracer
()
ctx
.
is_fw_autocast
=
False
if
tracer
.
_amp_level
==
core
.
AmpLevel
.
O0
else
True
if
tracer
.
_amp_level
==
core
.
AmpLevel
.
O2
:
ctx
.
amp_level
=
'O2'
elif
tracer
.
_amp_level
in
(
core
.
AmpLevel
.
O1
,
core
.
AmpLevel
.
O0
):
ctx
.
amp_level
=
'O1'
else
:
raise
ValueError
(
"unsupported amp level: {}"
.
format
(
tracer
.
_amp_level
))
ctx
.
amp_white_list
,
ctx
.
amp_black_list
=
tracer
.
_get_amp_op_list
()
with
paddle
.
no_grad
():
outputs
=
run_function
(
*
args
)
for
i
,
arg
in
enumerate
(
args
):
if
paddle
.
is_tensor
(
arg
):
state
=
arg
.
stop_gradient
if
_recompute_partition
:
ctx
.
tensor_shapes
.
append
(
arg
.
shape
)
partition
=
_split_activation
(
arg
.
detach
()).
clone
()
# TODO(shenliang03) not use calculate stream to D2H to speed
arg
=
partition
.
cpu
()
if
_recompute_offload
else
partition
else
:
arg
=
arg
.
cpu
()
if
_recompute_offload
else
arg
arg
.
stop_gradient
=
state
tensor_inputs
.
append
(
arg
)
ctx
.
tensor_indices
.
append
(
i
)
ctx
.
inputs
.
append
(
None
)
else
:
ctx
.
inputs
.
append
(
arg
)
ctx
.
save_for_backward
(
*
tensor_inputs
)
if
paddle
.
is_tensor
(
outputs
):
all_outputs
+=
[
outputs
]
return
outputs
else
:
all_outputs
+=
outputs
return
tuple
(
outputs
)
@
staticmethod
def
backward
(
ctx
,
*
args
):
with
paddle
.
fluid
.
dygraph
.
guard
():
# Restore inputs
inputs
=
list
(
ctx
.
inputs
)
tensor_indices
=
ctx
.
tensor_indices
tensor_shapes
=
ctx
.
tensor_shapes
tensors
=
list
(
ctx
.
saved_tensor
())
device_id
=
paddle
.
distributed
.
ParallelEnv
().
device_id
for
i
,
idx
in
enumerate
(
tensor_indices
):
if
_recompute_partition
:
state
=
tensors
[
i
].
stop_gradient
tensors
[
i
]
=
_merge_activation
(
tensors
[
i
]).
detach
(
).
reshape_
(
tensor_shapes
[
i
])
tensors
[
i
].
stop_gradient
=
state
inputs
[
idx
]
=
tensors
[
i
].
cuda
(
device_id
)
if
_recompute_offload
else
tensors
[
i
]
tracer
=
framework
.
_dygraph_tracer
()
tracer
.
_has_grad
=
True
# need restore auto_cast state as well as w/b list
with
_swith_rng_state_tracker
(
ctx
.
fwd_cuda_rng_state
,
ctx
.
fwd_cuda_rng_state_tracker
):
with
paddle
.
amp
.
auto_cast
(
enable
=
ctx
.
is_fw_autocast
,
custom_white_list
=
ctx
.
amp_white_list
,
custom_black_list
=
ctx
.
amp_black_list
,
level
=
ctx
.
amp_level
):
detached_inputs
=
detach_variable
(
tuple
(
inputs
))
outputs
=
ctx
.
run_function
(
*
detached_inputs
)
if
isinstance
(
outputs
,
core
.
eager
.
Tensor
):
outputs
=
(
outputs
,
)
assert
len
(
outputs
)
==
len
(
args
)
forward_outputs_with_grad
=
[]
backward_inputs
=
[]
for
i
in
range
(
len
(
outputs
)):
if
isinstance
(
outputs
[
i
],
core
.
eager
.
Tensor
)
and
not
outputs
[
i
].
stop_gradient
:
forward_outputs_with_grad
.
append
(
outputs
[
i
])
backward_inputs
.
append
(
args
[
i
])
if
len
(
forward_outputs_with_grad
)
==
0
:
raise
RuntimeError
(
"none of output has stop_gradient=False, this recompute() is not necessary"
)
# actually backward
paddle
.
autograd
.
backward
(
forward_outputs_with_grad
,
backward_inputs
)
grads
=
tuple
(
inp
.
_grad_ivar
()
for
inp
in
detached_inputs
if
isinstance
(
inp
,
core
.
eager
.
Tensor
))
return
grads
class
_HPRecomputeFunction
(
PyLayer
):
"""
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
...
...
@@ -290,8 +423,8 @@ class _HPRecomputeFunction(PyLayer):
# actually backward
paddle
.
autograd
.
backward
(
forward_outputs_with_grad
,
backward_inputs
)
grads
=
list
(
inp
.
_grad_ivar
()
for
inp
in
detached_inputs
if
isinstance
(
inp
,
core
.
VarBase
))
grads
=
tuple
(
inp
.
_grad_ivar
()
for
inp
in
detached_inputs
if
isinstance
(
inp
,
core
.
VarBase
))
return
grads
...
...
@@ -303,7 +436,10 @@ def _hp_recompute(function, *args):
# 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor
all_outputs
=
[]
_HPRecomputeFunction
.
apply
(
function
,
all_outputs
,
*
args
)
if
in_dygraph_mode
():
_HPEagerRecomputeFunction
.
apply
(
function
,
all_outputs
,
*
args
)
else
:
_HPRecomputeFunction
.
apply
(
function
,
all_outputs
,
*
args
)
if
len
(
all_outputs
)
==
1
:
return
all_outputs
[
0
]
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
7eae6570
...
...
@@ -1137,7 +1137,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties
(
test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 350
)
set_tests_properties
(
test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 300
)
set_tests_properties
(
test_parallel_dygraph_no_sync_gradient_check PROPERTIES TIMEOUT 30
)
set_tests_properties
(
test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT
2
00
)
set_tests_properties
(
test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT
5
00
)
set_tests_properties
(
test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py
浏览文件 @
7eae6570
...
...
@@ -112,10 +112,11 @@ class TestDistPPTraning(unittest.TestCase):
with
paddle
.
amp
.
auto_cast
():
loss_a
=
model_a
(
img
,
label
)
scaler_a
.
scale
(
loss_a
).
backward
()
scaler_a
.
minimize
(
optimizer_a
,
loss_a
)
optimizer_a
.
clear_grad
()
scheduler_a
.
step
()
scaler_a
.
scale
(
loss_a
).
backward
()
scaler_a
.
minimize
(
optimizer_a
,
loss_a
)
optimizer_a
.
clear_grad
()
scheduler_a
.
step
()
with
paddle
.
amp
.
auto_cast
():
loss_b
=
model_b
.
train_batch
(
...
...
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py
浏览文件 @
7eae6570
...
...
@@ -124,12 +124,12 @@ class TestDistPPTraning(unittest.TestCase):
with
paddle
.
amp
.
auto_cast
(
enable
=
True
,
level
=
'O2'
):
loss_a
=
model_a
(
img
,
label
)
scaler_a
.
scale
(
loss_a
).
backward
()
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
scaler_a
.
minimize
(
optimizer_a
,
loss_a
)
optimizer_a
.
clear_grad
()
scheduler_a
.
step
()
scaler_a
.
scale
(
loss_a
).
backward
()
scaler_a
.
minimize
(
optimizer_a
,
loss_a
)
optimizer_a
.
clear_grad
()
scheduler_a
.
step
()
with
paddle
.
amp
.
auto_cast
(
enable
=
True
,
level
=
'O2'
):
loss_b
=
model_b
.
train_batch
(
[
img
,
label
],
optimizer_b
,
scheduler_b
,
scaler
=
scaler_b
)
...
...
python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py
浏览文件 @
7eae6570
...
...
@@ -16,6 +16,7 @@ from __future__ import print_function
import
unittest
import
paddle.fluid
as
fluid
import
os
from
test_parallel_dygraph_dataparallel
import
TestMultipleGpus
...
...
@@ -23,31 +24,43 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
class
TestHybridPipeParallel
(
TestMultipleGpus
):
def
test_hybrid_parallel_pp_layer
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_layer.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_layer.py'
,
eager_mode
=
False
)
def
test_hybrid_parallel_pp_tuple_inputs
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_embedding.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_embedding.py'
,
eager_mode
=
False
)
def
test_hybrid_parallel_shared_weight
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_shared_weight.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_shared_weight.py'
,
eager_mode
=
False
)
def
test_pipeline_parallel_amp
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_amp.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_amp.py'
,
eager_mode
=
False
)
def
test_pipeline_parallel_fp16
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_fp16.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_fp16.py'
,
eager_mode
=
False
)
def
test_hybrid_parallel_transformer
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_transformer.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_transformer.py'
,
eager_mode
=
False
)
def
test_hybrid_parallel_save_load
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_save_load.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_save_load.py'
,
eager_mode
=
False
)
def
test_hybrid_parallel_recompute
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_recompute.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_recompute.py'
,
eager_mode
=
False
)
def
test_hybrid_parallel_pp_clip_grad
(
self
):
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_clip_grad.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_pp_clip_grad.py'
,
eager_mode
=
False
)
if
__name__
==
"__main__"
:
os
.
environ
[
"FLAGS_enable_eager_mode"
]
=
"1"
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_parallel_dygraph_tensor_parallel.py
浏览文件 @
7eae6570
...
...
@@ -23,7 +23,7 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
class
TestHybridParallel
(
TestMultipleGpus
):
def
test_hybrid_parallel_mp_random
(
self
):
#
self.run_mnist_2gpu('hybrid_parallel_mp_random.py')
self
.
run_mnist_2gpu
(
'hybrid_parallel_mp_random.py'
)
self
.
run_mnist_2gpu
(
'hybrid_parallel_mp_random.py'
,
eager_mode
=
False
)
def
test_hybrid_parallel_mp_model
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录