Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9048229b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9048229b
编写于
8月 28, 2019
作者:
C
chengduo
提交者:
GitHub
8月 28, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Cherry pick] Remove unnecessary op when trainable is false (#19434)
* fix optimizer bug test=develop
上级
5b3d33bd
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
103 addition
and
6 deletion
+103
-6
python/paddle/fluid/backward.py
python/paddle/fluid/backward.py
+1
-2
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+19
-4
python/paddle/fluid/tests/unittests/test_trainable.py
python/paddle/fluid/tests/unittests/test_trainable.py
+83
-0
未找到文件。
python/paddle/fluid/backward.py
浏览文件 @
9048229b
...
...
@@ -712,8 +712,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
parameters
=
parameter_list
else
:
params
=
program
.
global_block
().
all_parameters
()
program
.
global_block
().
iter_parameters
()
parameters
=
[
param
.
name
for
param
in
params
]
parameters
=
[
param
.
name
for
param
in
params
if
param
.
trainable
]
params_and_grads
=
[]
for
param
in
parameters
:
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
9048229b
...
...
@@ -360,8 +360,9 @@ class Optimizer(object):
global_block
=
framework
.
default_main_program
().
global_block
()
start
=
len
(
global_block
.
ops
)
self
.
helper
=
LayerHelper
(
self
.
__class__
.
__name__
)
self
.
_create_accumulators
(
global_block
,
[
p
[
0
]
for
p
in
parameters_and_grads
])
self
.
_create_accumulators
(
global_block
,
[
p
[
0
]
for
p
in
parameters_and_grads
if
p
[
0
].
trainable
])
self
.
_create_global_learning_rate
()
optimize_ops
=
[]
...
...
@@ -587,6 +588,20 @@ class Optimizer(object):
tuple: (optimize_ops, params_grads) which are, list of operators appended;
and list of (param, grad) Variables pair for optimization.
"""
assert
isinstance
(
loss
,
Variable
),
"The loss should be an Variable."
if
no_grad_set
is
None
:
no_grad_set
=
set
()
elif
isinstance
(
no_grad_set
,
set
)
or
isinstance
(
no_grad_set
,
list
)
or
isinstance
(
no_grad_set
,
tuple
):
no_grad_set
=
set
(
no_grad_set
)
else
:
assert
"no_grad_set should be a set, but the passed type is {}"
.
format
(
type
(
no_grad_set
))
parameters
=
loss
.
block
.
program
.
global_block
().
all_parameters
()
param_no_trainable
=
set
(
[
param
.
name
for
param
in
parameters
if
param
.
trainable
is
False
])
# If the parameter is no trainable, it should not have a gradient.
no_grad_set
.
update
(
param_no_trainable
)
params_grads
=
self
.
backward
(
loss
,
startup_program
=
startup_program
,
...
...
@@ -1390,7 +1405,7 @@ class AdamOptimizer(Optimizer):
assert
isinstance
(
block
,
framework
.
Block
)
main_block
=
block
.
program
.
global_block
()
for
param
,
grad
in
param_and_grads
:
if
grad
is
None
:
if
grad
is
None
or
param
.
trainable
is
False
:
continue
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
name_scope
(
"optimizer"
):
...
...
@@ -1553,7 +1568,7 @@ class AdamaxOptimizer(Optimizer):
assert
isinstance
(
block
,
framework
.
Block
)
main_block
=
block
.
program
.
global_block
()
for
param
,
grad
in
parameters_and_grads
:
if
grad
is
None
:
if
grad
is
None
or
param
.
trainable
is
False
:
continue
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
name_scope
(
'adamx'
):
...
...
python/paddle/fluid/tests/unittests/test_trainable.py
0 → 100644
浏览文件 @
9048229b
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
collections
import
Counter
import
unittest
import
paddle.fluid
as
fluid
from
simple_nets
import
init_data
def
test_trainable
():
x
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
feature
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
10
,
param_attr
=
fluid
.
ParamAttr
(
trainable
=
False
))
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
feature
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
loss
)
return
loss
class
TestTrainable
(
unittest
.
TestCase
):
def
check_trainable
(
self
,
model
,
feed_dict
,
op_count
,
optimizer
=
fluid
.
optimizer
.
Adam
()):
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
loss
=
model
()
optimizer
.
minimize
(
loss
)
# The number of adam should be one.
ops
=
Counter
([
op
.
type
for
op
in
main
.
global_block
().
ops
])
for
op
in
op_count
:
if
op_count
[
op
]
==
0
:
assert
op
not
in
ops
else
:
assert
ops
[
op
]
==
op_count
[
op
]
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
feed
=
feed_dict
)
def
test_trainable
(
self
):
batch_size
=
2
img
,
label
=
init_data
(
batch_size
,
img_shape
=
[
784
],
label_range
=
9
)
feed_dict
=
{
'image'
:
img
,
'label'
:
label
}
# Note that, because the Weight of FC is not trainable and the x is stop_gradient,
# so the 'mul_grad' should not be appended.
self
.
check_trainable
(
test_trainable
,
feed_dict
,
op_count
=
{
'adam'
:
1
,
'scale'
:
2
,
'mul_grad'
:
0
})
self
.
check_trainable
(
test_trainable
,
feed_dict
,
op_count
=
{
'adamax'
:
1
,
'scale'
:
1
,
'mul_grad'
:
0
},
optimizer
=
fluid
.
optimizer
.
Adamax
(
learning_rate
=
0.2
))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录