BaiXuePrincess / Paddle (fork of PaddlePaddle / Paddle)

Commit 64f780c0 (unverified)

Recompute upgrade (#47985)

Authored by wuhuachaocoding on Dec 20, 2022; committed via GitHub on Dec 20, 2022.
Parent: c830a28e

Showing 3 changed files with 253 additions and 97 deletions (+253, -97):
python/paddle/distributed/fleet/recompute/recompute.py                                    +135  -20
python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py              +1   -1
python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py  +117  -76
python/paddle/distributed/fleet/recompute/recompute.py

...
@@ -13,12 +13,16 @@
 # limitations under the License.

 import contextlib
+import weakref

 import paddle
+from paddle import framework
 from paddle.autograd import PyLayer
 from paddle.autograd.py_layer import LegacyPyLayer
-from paddle.fluid import core, framework
-from paddle.fluid.framework import in_dygraph_mode
+from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+    get_rng_state_tracker,
+)
+from paddle.framework import core, in_dygraph_mode

 from ..utils.log_util import logger
...
@@ -52,10 +56,6 @@ def check_recompute_necessary(inputs):

 @contextlib.contextmanager
 def swith_rng_state_tracker(rng_state, tracker):
-    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
-        get_rng_state_tracker,
-    )
-
     orig_cuda_rng_state = paddle.get_cuda_rng_state()
     orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
...
@@ -71,10 +71,6 @@ def swith_rng_state_tracker(rng_state, tracker):

 class LegacyRecomputeFunction(LegacyPyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
-            get_rng_state_tracker,
-        )
-
         # store for recomputing
         ctx.run_function = run_function
         ctx.preserve_rng_state = preserve_rng_state
...
@@ -223,10 +219,6 @@ class LegacyRecomputeFunction(LegacyPyLayer):

 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args, **kwargs):
-        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
-            get_rng_state_tracker,
-        )
-
         # store for recomputing
         ctx.run_function = run_function
         ctx.preserve_rng_state = preserve_rng_state
...
@@ -382,6 +374,116 @@ class RecomputeFunction(PyLayer):
             return grads


+def _recompute_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
+    """
+    recompute without reentrant, that means use hook to implement the recompute function rather than re-entrant autograd.
+    """
+    if preserve_rng_state:
+        cur_device = paddle.get_device()
+        if 'gpu:' not in cur_device:
+            raise RuntimeError(
+                "Recompute with RNG perserve is not support current device: {}.".format(
+                    cur_device
+                )
+            )
+        fw_cuda_rng_state = paddle.get_cuda_rng_state()
+        fwd_cuda_rng_state_tracker = (
+            get_rng_state_tracker().get_states_tracker()
+        )
+    tracer = framework._dygraph_tracer()
+    is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+    if tracer._amp_level == core.AmpLevel.O2:
+        amp_level = 'O2'
+    elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
+        amp_level = 'O1'
+
+    if tracer._amp_dtype == 'float16':
+        amp_dtype = 'float16'
+    elif tracer._amp_dtype in ('bfloat16', 'float32'):
+        amp_dtype = 'bfloat16'
+
+    amp_white_list, amp_black_list = tracer._get_amp_op_list()
+
+    class Intermediate_Holder:
+        pass
+
+    storage = weakref.WeakKeyDictionary()
+    holder_list = []
+
+    def pack(x):
+        res = Intermediate_Holder()
+        holder_list.append(weakref.ref(res))
+        return res
+
+    def unpack(x):
+        unpack_counter = 0
+        if len(storage) == 0:
+
+            def inner_pack(inner_x):
+                nonlocal unpack_counter
+                unpack_counter += 1
+
+                if holder_list[unpack_counter - 1]() is None:
+                    return
+
+                tmp_tensor = core.eager.Tensor(
+                    inner_x.dtype,
+                    inner_x.shape,
+                    inner_x.name + "cpy",
+                    core.VarDesc.VarType.LOD_TENSOR,
+                    inner_x.persistable,
+                )
+                inner_x._share_buffer_to(tmp_tensor)
+                storage[holder_list[unpack_counter - 1]()] = tmp_tensor
+                return
+
+            def inner_unpack(inner_x):
+                raise Exception("An unexcepted backward called on a tensor!")
+
+            if preserve_rng_state:
+                with swith_rng_state_tracker(
+                    fw_cuda_rng_state, fwd_cuda_rng_state_tracker
+                ):
+                    with paddle.set_grad_enabled(True):
+                        with paddle.amp.auto_cast(
+                            enable=is_fw_autocast,
+                            custom_white_list=amp_white_list,
+                            custom_black_list=amp_black_list,
+                            level=amp_level,
+                            dtype=amp_dtype,
+                        ):
+                            with paddle.autograd.saved_tensors_hooks(
+                                inner_pack, inner_unpack
+                            ):
+                                unused_outputs = function(*args, **kwargs)
+            else:
+                with paddle.set_grad_enabled(True), paddle.amp.auto_cast(
+                    enable=is_fw_autocast,
+                    custom_white_list=amp_white_list,
+                    custom_black_list=amp_black_list,
+                    level=amp_level,
+                    dtype=amp_dtype,
+                ), paddle.autograd.saved_tensors_hooks(inner_pack, inner_unpack):
+                    unused_outputs = function(*args, **kwargs)
+
+        if x not in storage:
+            raise Exception(
+                "Not supported to retrieve a tensor saved by autograd multiple times that is no need to recompute."
+            )
+
+        return storage[x]
+
+    with paddle.autograd.saved_tensors_hooks(pack, unpack):
+        outputs = function(*args, **kwargs)
+
+    return outputs
+
+
 def recompute(function, *args, **kwargs):
     """
     recompute intermediate activations to save then memory.
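
The hook-based path in _recompute_without_reentrant relies on paddle.autograd.saved_tensors_hooks: pack() runs whenever autograd wants to save an activation and keeps only a lightweight placeholder, and on the first backward access unpack() re-runs the wrapped function once under the captured RNG and AMP state to refill real tensors, then serves them from storage. A minimal sketch of just the hook mechanics, distinct from Paddle's actual implementation (the zero-filled unpack is purely illustrative):

import paddle

def pack(tensor):
    # save only cheap metadata instead of the activation itself
    return (tensor.shape, tensor.dtype)

def unpack(packed):
    shape, dtype = packed
    # a real implementation would re-run the forward here; returning any
    # tensor shows that unpack's result is what autograd consumes
    return paddle.zeros(shape, dtype)

x = paddle.randn([4, 4])
x.stop_gradient = False
with paddle.autograd.saved_tensors_hooks(pack, unpack):
    y = (x * x).sum()
y.backward()  # backward reads the unpacked tensors, not the originals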
...
@@ -391,11 +493,13 @@ def recompute(function, *args, **kwargs):
         whose intermediate activations will be released to save memory in forward stage and will be recomputed
         in backward stage for gradient calculation.
         *args(Tensor): inputs to the function.
-        **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to
-              indicate whether to save the forward rng. If it is True, then the last forward rng value will be
-              restored when the forward recalculation of backpropagation is performed. The default
-              preserve_rng_state is True.
+        **kwargs(Dict): Kwargs should only contain two kinds of key-value params, the one is part of function's key-value params,
+              and the other contains 'preserve_rng_state' and 'use_reentrant'. the key-value pair of preserve_rng_state,
+              which is used to indicate whether to save the forward rng. If it is True, then the last forward rng value
+              will be restored when the forward recalculation of backpropagation is performed, its default value is True.
+              the key-value pair of use_reentrant is used to indicate which implementation of recompute you will be used.
+              'use_reentrant=True' means to use the PyLayer implementation of recompute, 'use_reentrant=False' means to
+              use the Hook implementation of recompute, its default value is True.

     Returns:
         Output of function on args.
...
@@ -487,10 +591,21 @@ def recompute(function, *args, **kwargs):
     # Hack to mix *args with **kwargs in a python 2.7-compliant way
     preserve = kwargs.pop('preserve_rng_state', True)

+    # whether to use reentrant method to implement recompute
+    use_reentrant = kwargs.pop('use_reentrant', True)
+
+    if kwargs and use_reentrant:
+        raise ValueError(
+            "Error, if you want to send kwargs(dict parameter) to function, please set use_reentrant=False."
+        )
+
     if framework._dygraph_tracer()._has_grad:
         check_recompute_necessary(args)

-    return RecomputeFunction.apply(function, preserve, *args, **kwargs)
+    if use_reentrant:
+        return RecomputeFunction.apply(function, preserve, *args)
+    else:
+        return _recompute_without_reentrant(function, preserve, *args, **kwargs)


 def recompute_sequential(ctx, functions, *args, **kwargs):
...
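
Taken together, the changes to recompute() above dispatch between the two implementations. A minimal usage sketch under assumed conditions (a CUDA device, since preserve_rng_state defaults to True and the RNG capture requires one; the layer and shapes are hypothetical):

import paddle
from paddle.distributed.fleet.utils import recompute

paddle.set_device("gpu")

layer = paddle.nn.Sequential(
    paddle.nn.Linear(64, 64),
    paddle.nn.ReLU(),
    paddle.nn.Linear(64, 64),
)
x = paddle.randn([8, 64])
x.stop_gradient = False

# default use_reentrant=True: PyLayer-based recompute, positional inputs only
y = recompute(layer, x)

# use_reentrant=False: hook-based recompute; keyword arguments for the
# wrapped function are only accepted on this path
y = recompute(layer, x, use_reentrant=False)
y.sum().backward()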
python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py

...
@@ -272,7 +272,7 @@ class TestPyLayer(unittest.TestCase):
     def test_recompute_kwargs(self):
         paddle.set_device("gpu")
         kwargs = {"is_test": False}
-        with self.assertRaises(TypeError):
+        with self.assertRaises(ValueError):
             loss_ref, param_ref, grad_ref = run_model(
                 recompute_block=[2], recompute_kwargs=kwargs
             )
...
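
The switch from TypeError to ValueError tracks the new validation in recompute(): passing function kwargs while use_reentrant is True (the default) now fails fast with the ValueError raised in the dispatch code, before any PyLayer machinery runs. A condensed sketch of what this test asserts (the function and shapes are hypothetical):

import paddle
from paddle.distributed.fleet.utils import recompute

def fn(x, is_test=False):
    return paddle.nn.functional.relu(x)

x = paddle.randn([2, 2])
x.stop_gradient = False
try:
    recompute(fn, x, is_test=False)  # function kwargs + default use_reentrant=True
except ValueError:
    pass  # "...please set use_reentrant=False."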
python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py

...
@@ -21,32 +21,44 @@ import paddle
 from paddle.distributed.fleet.utils import recompute


+class Model(paddle.nn.Layer):
+    def __init__(self, block_idx, input_size, is_last=False):
+        super(Model, self).__init__()
+        block_name = "block_" + str(block_idx)
+        self.block = paddle.nn.Sequential(
+            (
+                block_name + "_fc_0",
+                paddle.nn.Linear(input_size, input_size, bias_attr=False),
+            ),
+            (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
+            (block_name + "_relu_1", paddle.nn.ReLU()),
+            (
+                block_name + "_fc_1",
+                paddle.nn.Linear(input_size, input_size, bias_attr=False),
+            ),
+            (block_name + "_relu_2", paddle.nn.ReLU()),
+        )
+        if is_last:
+            self.block.add_sublayer(
+                block_name + "_fc_2",
+                paddle.nn.Linear(input_size, 1, bias_attr=False),
+            )  # add sublayer
+        else:
+            self.block.add_sublayer(
+                block_name + "_fc_2",
+                paddle.nn.Linear(input_size, input_size, bias_attr=False),
+            )  # add sublayer
+
+    # add pos param for test kwargs of recompute.
+    def forward(self, x, pos=None):
+        if pos is None:
+            return self.block(x)
+        else:
+            return self.block(x) + pos
+
+
 def get_fc_block(block_idx, input_size, is_last=False):
-    block_name = "block_" + str(block_idx)
-    block = paddle.nn.Sequential(
-        (
-            block_name + "_fc_0",
-            paddle.nn.Linear(input_size, input_size, bias_attr=False),
-        ),
-        (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
-        (block_name + "_relu_1", paddle.nn.ReLU()),
-        (
-            block_name + "_fc_1",
-            paddle.nn.Linear(input_size, input_size, bias_attr=False),
-        ),
-        (block_name + "_relu_2", paddle.nn.ReLU()),
-    )
-    if is_last:
-        block.add_sublayer(
-            block_name + "_fc_2",
-            paddle.nn.Linear(input_size, 1, bias_attr=False),
-        )  # add sublayer
-    else:
-        block.add_sublayer(
-            block_name + "_fc_2",
-            paddle.nn.Linear(input_size, input_size, bias_attr=False),
-        )  # add sublayer
-    return block
+    return Model(block_idx, input_size, is_last=False)


 class Naive_fc_net(paddle.nn.Layer):
...
@@ -143,6 +155,10 @@ def run_model(
         segments=segments,
         recompute_kwargs=recompute_kwargs,
     )

+    if pure_fp16:
+        model = paddle.amp.decorate(models=model, level='O2')
+
     loss_fn = paddle.nn.MSELoss(reduction='mean')
     optimizer = paddle.optimizer.SGD(
         learning_rate=0.01, parameters=model.parameters()
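
The new pure_fp16 branch decorates the model before the optimizer is built, so parameters are prepared for pure-fp16 (O2) execution. A minimal sketch of that call in isolation (the layer is hypothetical):

import paddle

paddle.set_device("gpu")  # O2 decoration targets GPU execution
model = paddle.nn.Linear(4, 4)
# cast the model for pure-fp16 training, mirroring the branch above
model = paddle.amp.decorate(models=model, level='O2')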
...
@@ -158,7 +174,7 @@ def run_model(
     x_data = np.random.randn(batch_size, input_size).astype(np.float32)
     x = paddle.to_tensor(x_data)
-    # x.stop_gradient = False
+    x.stop_gradient = False
     level = 'O2' if pure_fp16 else 'O1'
     with paddle.amp.auto_cast(True, level=level):
         y_pred = model(x)
...
@@ -178,7 +194,7 @@ def run_model(
     return loss_, param_, grad_


-class TestPyLayer(unittest.TestCase):
+class TestRecompute(unittest.TestCase):
     def test_base_case(self, enable_autocast=False, pure_fp16=False):
         def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
             self.assertEqual(loss_ref, loss)
...
@@ -192,46 +208,55 @@ class TestPyLayer(unittest.TestCase):
             pure_fp16=pure_fp16,
         )

-        # recompute second block
-        loss, param, grad = run_model(
-            recompute_block=[1],
-            enable_autocast=enable_autocast,
-            pure_fp16=pure_fp16,
-        )
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute fourth block
-        loss, param, grad = run_model(
-            recompute_block=[3],
-            enable_autocast=enable_autocast,
-            pure_fp16=pure_fp16,
-        )
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute second to fourth block
-        loss, param, grad = run_model(
-            recompute_block=[1, 2, 3],
-            enable_autocast=enable_autocast,
-            pure_fp16=pure_fp16,
-        )
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute second & fourth block
-        loss, param, grad = run_model(
-            recompute_block=[1, 3],
-            enable_autocast=enable_autocast,
-            pure_fp16=pure_fp16,
-        )
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute_sequential with segments=1 using fleet
-        loss, param, grad = run_model(
-            recompute_block=[],
-            use_fleet_sq=True,
-            enable_autocast=enable_autocast,
-            pure_fp16=pure_fp16,
-        )
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+        # test for recompute
+        # True: PyLayer of recompute
+        # False: HooK of recompute
+        for flag in [True, False]:
+            # recompute second block
+            loss, param, grad = run_model(
+                recompute_block=[1],
+                enable_autocast=enable_autocast,
+                pure_fp16=pure_fp16,
+                recompute_kwargs={"use_reentrant": flag},
+            )
+            check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+            # recompute fourth block
+            loss, param, grad = run_model(
+                recompute_block=[3],
+                enable_autocast=enable_autocast,
+                pure_fp16=pure_fp16,
+                recompute_kwargs={"use_reentrant": flag},
+            )
+            check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+            # recompute second to fourth block
+            loss, param, grad = run_model(
+                recompute_block=[1, 2, 3],
+                enable_autocast=enable_autocast,
+                pure_fp16=pure_fp16,
+                recompute_kwargs={"use_reentrant": flag},
+            )
+            check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+            # recompute second & fourth block
+            loss, param, grad = run_model(
+                recompute_block=[1, 3],
+                enable_autocast=enable_autocast,
+                pure_fp16=pure_fp16,
+                recompute_kwargs={"use_reentrant": flag},
+            )
+            check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+            # recompute_sequential with segments=1 using fleet
+            loss, param, grad = run_model(
+                recompute_block=[],
+                use_fleet_sq=True,
+                enable_autocast=enable_autocast,
+                pure_fp16=pure_fp16,
+                recompute_kwargs={"use_reentrant": flag},
+            )
+            check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

         # with base recompute, and segments=2
         loss_ref, param_ref, grad_ref = run_model(
...
@@ -255,11 +280,15 @@ class TestPyLayer(unittest.TestCase):
         self.test_base_case()

     def test_fc_net_without_restore_rng(self):
-        loss_ref, param_ref, grad_ref = run_model(
-            recompute_block=[2],
-            recompute_kwargs={"preserve_rng_state": False},
-            enable_autocast=True,
-        )
+        for flag in [True, False]:
+            loss_ref, param_ref, grad_ref = run_model(
+                recompute_block=[2],
+                recompute_kwargs={
+                    "preserve_rng_state": False,
+                    "use_reentrant": flag,
+                },
+                enable_autocast=True,
+            )

     def test_fc_net_with_amp(self):
         self.test_base_case(enable_autocast=True)
...
@@ -269,16 +298,28 @@ class TestPyLayer(unittest.TestCase):
     def test_recompute_kwargs(self):
         paddle.set_device("gpu")
-        kwargs = {"is_test": False}
-        with self.assertRaises(TypeError):
+        pos = paddle.randn(shape=[10, 10], dtype="float32")
+        pos.stop_gradient = False
+
+        kwargs = {"pos": pos, "use_reentrant": True}
+        with self.assertRaises(ValueError):
             loss_ref, param_ref, grad_ref = run_model(
                 recompute_block=[2], recompute_kwargs=kwargs
             )

+        kwargs = {"pos": pos, "use_reentrant": False}
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[2], recompute_kwargs=kwargs
+        )
+
     def test_recompute_cpu_rng(self):
         paddle.set_device("cpu")
-        with self.assertRaises(RuntimeError):
-            loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
+        for flag in [True, False]:
+            with self.assertRaises(RuntimeError):
+                loss_ref, param_ref, grad_ref = run_model(
+                    recompute_block=[2],
+                    recompute_kwargs={"use_reentrant": flag},
+                )


 if __name__ == '__main__':
...
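
The reworked test_recompute_kwargs also demonstrates the supported way to feed keyword arguments through recompute: route them via the hook-based path. A condensed sketch of that pattern under the same assumptions as the test (a GPU device; the block and shapes are hypothetical):

import paddle
from paddle.distributed.fleet.utils import recompute

class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(10, 10, bias_attr=False)

    def forward(self, x, pos=None):
        y = self.fc(x)
        return y if pos is None else y + pos

paddle.set_device("gpu")
block = Block()
x = paddle.randn([10, 10])
x.stop_gradient = False
pos = paddle.randn([10, 10])
pos.stop_gradient = False

# keyword arguments reach Block.forward only with use_reentrant=False
out = recompute(block, x, pos=pos, use_reentrant=False)
out.sum().backward()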