Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
bb2310a6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bb2310a6
编写于
8月 31, 2023
作者:
Y
Yuang Liu
提交者:
GitHub
8月 31, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
recompute support tuple (#56793)
上级
23bc4c26
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
155 addition
and
16 deletion
+155
-16
python/paddle/distributed/fleet/recompute/recompute.py
python/paddle/distributed/fleet/recompute/recompute.py
+66
-16
test/legacy_test/test_recompute_with_tuple_input.py
test/legacy_test/test_recompute_with_tuple_input.py
+89
-0
未找到文件。
python/paddle/distributed/fleet/recompute/recompute.py
浏览文件 @
bb2310a6
...
@@ -31,10 +31,24 @@ __all__ = []
...
@@ -31,10 +31,24 @@ __all__ = []
def
detach_variable
(
inputs
):
def
detach_variable
(
inputs
):
out
=
[]
out
=
[]
for
inp
in
inputs
:
for
inp
in
inputs
:
if
not
isinstance
(
inp
,
core
.
eager
.
Tensor
):
if
not
isinstance
(
inp
,
core
.
eager
.
Tensor
)
and
(
type
(
inp
)
is
not
tuple
or
not
isinstance
(
inp
[
0
],
core
.
eager
.
Tensor
)
):
# the inp is not a tensor or not a tuple of tensors
out
.
append
(
inp
)
out
.
append
(
inp
)
continue
continue
if
type
(
inp
)
is
tuple
:
detach_inp
=
[]
for
i
in
inp
:
# detach all tensors in the tuple
assert
isinstance
(
i
,
core
.
eager
.
Tensor
)
tmp_i
=
i
.
detach
()
tmp_i
.
stop_gradient
=
i
.
stop_gradient
detach_inp
.
append
(
tmp_i
)
out
.
append
(
tuple
(
detach_inp
))
continue
x
=
inp
.
detach
()
x
=
inp
.
detach
()
x
.
stop_gradient
=
inp
.
stop_gradient
x
.
stop_gradient
=
inp
.
stop_gradient
out
.
append
(
x
)
out
.
append
(
x
)
...
@@ -42,11 +56,16 @@ def detach_variable(inputs):
...
@@ -42,11 +56,16 @@ def detach_variable(inputs):
def
check_recompute_necessary
(
inputs
):
def
check_recompute_necessary
(
inputs
):
if
not
any
(
necessary_for_each_input
=
[]
not
input_
.
stop_gradient
for
input_
in
inputs
:
for
input_
in
inputs
if
isinstance
(
input_
,
(
core
.
eager
.
Tensor
,
paddle
.
Tensor
)):
if
isinstance
(
input_
,
(
core
.
eager
.
Tensor
,
paddle
.
Tensor
))
necessary_for_each_input
.
append
(
input_
.
stop_gradient
)
):
elif
type
(
input_
)
is
tuple
:
for
i
in
input_
:
# traverse all tensors in the tuple
if
isinstance
(
i
,
(
core
.
eager
.
Tensor
,
paddle
.
Tensor
)):
necessary_for_each_input
.
append
(
i
.
stop_gradient
)
if
all
(
necessary_for_each_input
):
logger
.
warning
(
logger
.
warning
(
"[Recompute]: None of the inputs to current recompute block need grad, "
"[Recompute]: None of the inputs to current recompute block need grad, "
"therefore there is NO need to recompute this block in backward !"
"therefore there is NO need to recompute this block in backward !"
...
@@ -81,12 +100,37 @@ class RecomputeFunction(PyLayer):
...
@@ -81,12 +100,37 @@ class RecomputeFunction(PyLayer):
# save input for backward
# save input for backward
ctx
.
inputs
=
[]
ctx
.
inputs
=
[]
ctx
.
tensor_indices
=
[]
ctx
.
tensor_indices
=
[]
ctx
.
duplicate_tensor
=
[
False
for
_
in
range
(
len
(
args
))]
tensor_inputs
=
[]
tensor_inputs
=
[]
for
i
,
arg
in
enumerate
(
args
):
for
i
,
arg
in
enumerate
(
args
):
if
paddle
.
is_tensor
(
arg
):
if
paddle
.
is_tensor
(
arg
):
tensor_inputs
.
append
(
arg
)
tensor_inputs
.
append
(
arg
)
ctx
.
tensor_indices
.
append
(
i
)
ctx
.
tensor_indices
.
append
(
i
)
ctx
.
inputs
.
append
(
None
)
ctx
.
inputs
.
append
(
None
)
elif
type
(
arg
)
is
tuple
:
is_tensors
=
[
paddle
.
is_tensor
(
a
)
for
a
in
arg
]
if
all
(
is_tensors
):
# the tuple is a tuple of tensors
tensors_stop_gradient
=
[
a
.
stop_gradient
for
a
in
arg
]
if
not
all
(
tensors_stop_gradient
)
and
any
(
tensors_stop_gradient
):
# tensors in the tuple have different stop_gradient value, which pylayer doesn't support
raise
ValueError
(
"Recompute receive a tuple containing tensor holds different stop gradient."
)
tensor_inputs
.
append
(
arg
)
ctx
.
tensor_indices
.
append
(
i
)
# Mark the tuple is a tuple of tensors
ctx
.
duplicate_tensor
[
i
]
=
True
ctx
.
inputs
.
append
(
None
)
elif
any
(
is_tensors
):
# the tuple contains tensors and non-tensor values
raise
ValueError
(
"Recompute receive a tuple containing tensor and non-tensor at same time."
)
else
:
ctx
.
inputs
.
append
(
arg
)
else
:
else
:
ctx
.
inputs
.
append
(
arg
)
ctx
.
inputs
.
append
(
arg
)
ctx
.
save_for_backward
(
*
tensor_inputs
)
ctx
.
save_for_backward
(
*
tensor_inputs
)
...
@@ -132,6 +176,7 @@ class RecomputeFunction(PyLayer):
...
@@ -132,6 +176,7 @@ class RecomputeFunction(PyLayer):
# Restore inputs
# Restore inputs
inputs
=
list
(
ctx
.
inputs
)
inputs
=
list
(
ctx
.
inputs
)
tensor_indices
=
ctx
.
tensor_indices
tensor_indices
=
ctx
.
tensor_indices
duplicate_tensor
=
ctx
.
duplicate_tensor
tensors
=
ctx
.
saved_tensor
()
tensors
=
ctx
.
saved_tensor
()
for
i
,
idx
in
enumerate
(
tensor_indices
):
for
i
,
idx
in
enumerate
(
tensor_indices
):
inputs
[
idx
]
=
tensors
[
i
]
inputs
[
idx
]
=
tensors
[
i
]
...
@@ -198,18 +243,23 @@ class RecomputeFunction(PyLayer):
...
@@ -198,18 +243,23 @@ class RecomputeFunction(PyLayer):
forward_outputs_with_grad
,
backward_inputs_with_grad
forward_outputs_with_grad
,
backward_inputs_with_grad
)
)
grads
=
[]
for
idx
,
inp
in
enumerate
(
detached_inputs
):
if
isinstance
(
inp
,
core
.
eager
.
Tensor
):
grads
.
append
(
inp
.
_grad_ivar
())
elif
type
(
inp
)
is
tuple
and
duplicate_tensor
[
idx
]:
# input is a tuple and is a tuple of tensors
if
all
(
i
.
stop_gradient
for
i
in
inp
):
# all tensors in the tuple doesn't need grad, only return a None for the whole tuple
grads
.
append
(
None
)
else
:
# all tensors in the tuple nees grad, should return a tuple of grads
grads
.
append
(
tuple
(
i
.
_grad_ivar
()
for
i
in
inp
))
if
in_dynamic_mode
():
if
in_dynamic_mode
():
grads
=
tuple
(
grads
=
tuple
(
grads
)
inp
.
_grad_ivar
()
for
inp
in
detached_inputs
if
isinstance
(
inp
,
core
.
eager
.
Tensor
)
)
else
:
else
:
grads
=
[
grads
=
list
(
grads
)
inp
.
_grad_ivar
()
for
inp
in
detached_inputs
if
isinstance
(
inp
,
core
.
eager
.
Tensor
)
]
return
grads
return
grads
...
...
test/legacy_test/test_recompute_with_tuple_input.py
0 → 100644
浏览文件 @
bb2310a6
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
from
paddle.distributed.fleet.utils
import
recompute
class
Layer
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
):
super
().
__init__
()
self
.
linear1
=
paddle
.
nn
.
Linear
(
10
,
10
)
self
.
linear2
=
paddle
.
nn
.
Linear
(
10
,
10
)
self
.
linear3
=
paddle
.
nn
.
Linear
(
10
,
10
)
self
.
silu1
=
paddle
.
nn
.
Silu
()
self
.
silu2
=
paddle
.
nn
.
Silu
()
self
.
silu3
=
paddle
.
nn
.
Silu
()
def
forward
(
self
,
x
,
y
):
assert
type
(
x
)
is
tuple
assert
len
(
x
)
==
2
o1
=
self
.
silu1
(
self
.
linear1
(
x
[
0
]))
o2
=
self
.
silu2
(
self
.
linear2
(
x
[
1
]))
o3
=
self
.
silu3
(
self
.
linear3
(
y
))
o
=
o1
+
o2
+
o3
return
o
class
TestPyLayer
(
unittest
.
TestCase
):
def
test_tuple_input
(
self
):
layer
=
Layer
()
x1
=
paddle
.
rand
(
shape
=
[
10
,
10
])
x1
.
stop_gradient
=
False
x2
=
paddle
.
rand
(
shape
=
[
10
,
10
])
x2
.
stop_gradient
=
False
y
=
paddle
.
rand
(
shape
=
[
10
,
10
])
y
.
stop_gradient
=
False
o
=
recompute
(
layer
,
(
x1
,
x2
),
y
)
loss
=
paddle
.
mean
(
o
,
keepdim
=
True
)
loss
.
backward
()
def
test_tuple_input_with_non_tensor
(
self
):
layer
=
Layer
()
x1
=
paddle
.
rand
(
shape
=
[
10
,
10
])
x1
.
stop_gradient
=
False
y
=
paddle
.
rand
(
shape
=
[
10
,
10
])
y
.
stop_gradient
=
False
try
:
o
=
recompute
(
layer
,
(
x1
,
True
),
y
)
except
ValueError
:
pass
def
test_tuple_input_with_different_stop_gradient
(
self
):
layer
=
Layer
()
x1
=
paddle
.
rand
(
shape
=
[
10
,
10
])
x1
.
stop_gradient
=
False
x2
=
paddle
.
rand
(
shape
=
[
10
,
10
])
y
=
paddle
.
rand
(
shape
=
[
10
,
10
])
y
.
stop_gradient
=
False
try
:
o
=
recompute
(
layer
,
(
x1
,
True
),
y
)
except
ValueError
:
pass
def
test_tuple_input_all_no_gradient
(
self
):
layer
=
Layer
()
x1
=
paddle
.
rand
(
shape
=
[
10
,
10
])
x2
=
paddle
.
rand
(
shape
=
[
10
,
10
])
y
=
paddle
.
rand
(
shape
=
[
10
,
10
])
y
.
stop_gradient
=
False
o
=
recompute
(
layer
,
(
x1
,
x2
),
y
)
loss
=
paddle
.
mean
(
o
,
keepdim
=
True
)
loss
.
backward
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录