Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a1d9a14e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a1d9a14e
编写于
12月 28, 2020
作者:
C
Chen Weihang
提交者:
GitHub
12月 29, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support grad accumulated across batch (#29942)
上级
bb20dcfc
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
31 addition
and
7 deletion
+31
-7
paddle/fluid/imperative/gradient_accumulator.h
paddle/fluid/imperative/gradient_accumulator.h
+1
-0
python/paddle/fluid/tests/unittests/test_complex_grad_accumulated.py
...le/fluid/tests/unittests/test_complex_grad_accumulated.py
+30
-7
未找到文件。
paddle/fluid/imperative/gradient_accumulator.h
浏览文件 @
a1d9a14e
...
...
@@ -45,6 +45,7 @@ class GradientAccumulator {
inner_var_
=
std
::
make_shared
<
VariableWrapper
>
(
var
->
Name
());
inner_var_
->
SetType
(
var
->
Type
());
inner_var_
->
SetDataType
(
var
->
DataType
());
inner_var_
->
SetForwardDataType
(
var
->
ForwardDataType
());
inner_var_
->
InnerSetOverridedStopGradient
(
var
->
InnerOverridedStopGradient
());
VLOG
(
6
)
<<
" Create inner grad var for ("
<<
var
->
Name
()
...
...
python/paddle/fluid/tests/unittests/test_complex_grad_accumulated.py
浏览文件 @
a1d9a14e
...
...
@@ -41,7 +41,6 @@ class Optimization_ex1(paddle.nn.Layer):
np
.
random
.
random
((
4
,
4
)).
astype
(
dtype
)
+
np
.
random
.
random
(
(
4
,
4
)).
astype
(
dtype
)
*
1j
,
stop_gradient
=
False
)
print
(
self
.
A
)
def
forward
(
self
,
mode
=
1
):
jj
=
paddle
.
to_tensor
(
np
.
array
([
1j
]).
astype
(
np
.
complex64
))
...
...
@@ -70,31 +69,55 @@ class TestComplexGradAccumulated(unittest.TestCase):
self
.
devices
=
[
'cpu'
]
if
core
.
is_compiled_with_cuda
():
self
.
devices
.
append
(
'gpu'
)
self
.
iter
=
3
self
.
learning_rate
=
0.5
self
.
dtypes
=
[
'float32'
,
'float64'
]
self
.
theta_size
=
[
4
,
4
]
def
run_backward
(
self
,
device
,
dtype
,
mode
):
def
train
(
self
,
device
,
dtype
,
mode
):
paddle
.
set_device
(
device
)
myLayer
=
Optimization_ex1
(
self
.
theta_size
,
dtype
)
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
self
.
learning_rate
,
parameters
=
myLayer
.
parameters
())
loss
=
myLayer
(
mode
)
loss
.
backward
()
for
iter
in
range
(
self
.
iter
):
loss
=
myLayer
(
mode
)
loss
.
backward
()
optimizer
.
step
()
optimizer
.
clear_grad
()
def
train_no_clear_grad
(
self
,
device
,
dtype
,
mode
):
paddle
.
set_device
(
device
)
myLayer
=
Optimization_ex1
(
self
.
theta_size
,
dtype
)
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
self
.
learning_rate
,
parameters
=
myLayer
.
parameters
())
for
iter
in
range
(
self
.
iter
):
loss
=
myLayer
(
mode
)
loss
.
backward
()
optimizer
.
step
()
def
test_case_one_step
(
self
):
for
dev
in
self
.
devices
:
for
dtype
in
self
.
dtypes
:
self
.
run_backward
(
dev
,
dtype
,
1
)
self
.
train
(
dev
,
dtype
,
1
)
self
.
train_no_clear_grad
(
dev
,
dtype
,
1
)
def
test_case_two_step
(
self
):
for
dev
in
self
.
devices
:
for
dtype
in
self
.
dtypes
:
self
.
run_backward
(
dev
,
dtype
,
2
)
self
.
train
(
dev
,
dtype
,
2
)
self
.
train_no_clear_grad
(
dev
,
dtype
,
2
)
def
test_case_non_param
(
self
):
for
dev
in
self
.
devices
:
for
dtype
in
self
.
dtypes
:
self
.
run_backward
(
dev
,
dtype
,
3
)
self
.
train
(
dev
,
dtype
,
3
)
self
.
train_no_clear_grad
(
dev
,
dtype
,
3
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录