Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e0d8c6ac
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e0d8c6ac
编写于
5年前
作者:
C
chengduo
提交者:
GitHub
5年前
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add find_no_grad_vars in backward.py (#17942)
* add not_been_used_vars to no_grad_set test=develop
上级
449c7a9f
develop
2.0.1-rocm-post
Ligoml-patch-1
OliverLPH-patch-1
OliverLPH-patch-2
PaddlePM-patch-1
PaddlePM-patch-2
ZHUI-patch-1
add_default_att
add_model_benchmark_ci
add_some_yaml_config
addfile
all_new_design_exec
ascendrc
ascendrelease
cherry_undefined_var
compile_windows
delete_2.0.1-rocm-post
delete_add_default_att
delete_all_new_design_exec
delete_ascendrc
delete_compile_windows
delete_delete_addfile
delete_disable_iterable_dataset_unittest
delete_fix_dataloader_memory_leak
delete_fix_imperative_dygraph_error
delete_fix_retry_ci
delete_fix_undefined_var
delete_improve_sccache
delete_paddle_tiny_install
delete_paralleltest
delete_prv-disable-more-cache
delete_revert-31068-fix_conv3d_windows
delete_revert-31562-mean
delete_revert-33630-bug-fix
delete_revert-34159-add_npu_bce_logical_dev
delete_revert-34910-spinlocks_for_allocator
delete_revert-35069-revert-34910-spinlocks_for_allocator
delete_revert-36057-dev/read_flags_in_ut
dingjiaweiww-patch-1
disable_iterable_dataset_unittest
dy2static
enable_eager_model_test
final_state_gen_python_c
final_state_intermediate
fix-numpy-issue
fix_concat_slice
fix_dataloader_memory_leak
fix_imperative_dygraph_error
fix_npu_ci
fix_op_flops
fix_retry_ci
fix_rnn_docs
fix_tensor_type
fix_undefined_var
fixiscan
fixiscan1
fixiscan2
fixiscan3
github/fork/123malin/netifaces
github/fork/123malin/tdm_abacus
github/fork/AshburnLee/dev_unique
github/fork/ForFishes/fix_memory_matmul
github/fork/ForFishes/rm_fluid
github/fork/LielinJiang/move-2.0-api
github/fork/LielinJiang/visual-dl-cb
github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api
github/fork/LiuChiachi/fix-example-code-for-hapi-Model
github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model
github/fork/MrChengmo/fix_ps_profiler
github/fork/MrChengmo/update_ps_heter
github/fork/PWhiddy/patch-1
github/fork/Shixiaowei02/dev/save_load_upgrade
github/fork/TCChenlong/fix_hapi
github/fork/TCChenlong/fix_inden
github/fork/Thunderbrook/xpu_slice
github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var
github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_2
github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3
github/fork/XieYunshen/timeout_20S_ut
github/fork/ZeyuChen/remove-nltk
github/fork/arlesniak/arlesniak/selective__mkldnn_flags
github/fork/baiyfbupt/code_doc_mig
github/fork/chalsliu/set_timeout
github/fork/chen-zhiyu/develop
github/fork/chenwhql/ci/try_to_find_test_buffer_shared_memory_reuse_pass_error
github/fork/chenwhql/dygraph/remove_scale_loss_and_apply_collective_grads
github/fork/chenwhql/saveload/add_get_inference_program
github/fork/chenwhql/saveload/remove_save_load_config
github/fork/cryoco/pass-compatibility-trt
github/fork/danleifeng/isempty_api2.0
github/fork/frankwhzhang/api_transfer
github/fork/hbwx24/error_msg/cuda_kernel_error_msg
github/fork/heavengate/cherry_yolo_box
github/fork/heavengate/update_yolo_box
github/fork/iclementine/rnn_fix
github/fork/iducn/testestse
github/fork/jczaja/prv-25537-fix
github/fork/jeff41404/release/1.8
github/fork/jiweibo/api_2.0
github/fork/jiweibo/fix_lite_resnet50_test
github/fork/juncaipeng/fix_doc_1
github/fork/lfchener/sample_code
github/fork/littletomatodonkey/fix_reg_doc
github/fork/liym27/dy2stat_update_assign_to_rc20
github/fork/luotao1/profiler_ut
github/fork/mapingshuo/add_wait
github/fork/mapingshuo/doc_2.0
github/fork/mapingshuo/zero-0.5
github/fork/miraiwk/dev
github/fork/pangyoki/add-Categorical-class-branch
github/fork/pangyoki/add-multinomial-op-branch
github/fork/pangyoki/fix-test_distritbution-CI
github/fork/qjing666/doublegrad
github/fork/qjing666/fix_hdfs_download
github/fork/sandyhouse/add_gather_etc
github/fork/sandyhouse/add_send_recv_alltoall_etc
github/fork/sandyhouse/pipeline_exe_run
github/fork/seiriosPlus/feature/large_scale_kv_save_delta
github/fork/seiriosPlus/fix/paddle_errors_fix
github/fork/seiriosPlus/fix/paddle_op_errors
github/fork/shangzhizhou/fix_test_activation_op_random_bug
github/fork/smallv0221/yxp0924
github/fork/smallv0221/yxp0925
github/fork/swtkiwi/del-matplotlib
github/fork/tianshuo78520a/kunlun_test
github/fork/tianshuo78520a/update_dockerfile
github/fork/wanghaoshuang/bert_fuse
github/fork/wanghaoshuang/label_smooth
github/fork/wanghuancoder/develop_CUDASynchronize
github/fork/wanghuancoder/develop_Layer_doc
github/fork/wanghuancoder/develop_ParameterList_doc
github/fork/wanghuancoder/develop_Sequential_doc
github/fork/wanghuancoder/develop_bilinear_tensor_product
github/fork/wanghuancoder/develop_coverage_build_sh
github/fork/wanghuancoder/develop_in_dynamic_mode_doc
github/fork/wanghuancoder/develop_unique_name_doc
github/fork/wangxicoding/fleet_meta_combine
github/fork/wawltor/error_message_fix_5
github/fork/willthefrog/remove_l2_norm
github/fork/windstamp/momentum_op
github/fork/windstamp/mv_op_5
github/fork/windstamp/normal_api
github/fork/wojtuss/wojtuss/fusion_gru_quantization
github/fork/wojtuss/wojtuss/quantization-with-shift
github/fork/wzzju/fix_err_info
github/fork/wzzju/pure_fp16
github/fork/xiemoyuan/op_error_message
github/fork/xiemoyuan/optimize_error_message
github/fork/yaoxuefeng6/fix_doc
github/fork/yaoxuefeng6/mod_dataset_v2
github/fork/yongqiangma/lod
github/fork/ysh329/fix-clip-by-norm-error
github/fork/ysh329/fix-error-clip-by-value
github/fork/yukavio/error_info
github/fork/zhangting2020/conv_filter_grad
github/fork/zhangting2020/is_compile_with_cuda
github/fork/zhangting2020/place_doc
github/fork/zhangting2020/program
github/fork/zhhsplendid/fix_any
github/fork/zhhsplendid/refine_api2
github/fork/zhhsplendid/refine_api2_test
github/fork/zhhsplendid/refine_api_test_ptb_lm
github/fork/zhhsplendid/refine_api_test_resnet
github/fork/zhhsplendid/refine_api_test_simnet
github/fork/zhiqiu/dev/refine_initializer
github/fork/zhiqiu/dev/remove_inplace_argument
github/fork/zlsh80826/nvinfer_plugin_var_len_cuda11
improve_sccache
incubate/infrt
inplace_addto
make_flag_adding_easier
move_embedding_to_phi
move_histogram_to_pten
move_sgd_to_phi
move_slice_to_pten
move_temporal_shift_to_phi
move_yolo_box_to_phi
npu_fix_alloc
numel
paddle_tiny_install
paralleltest
preln_ernie
prv-disable-more-cache
prv-md-even-more
prv-onednn-2.5
pten_tensor_refactor
release/1.6
release/1.7
release/1.8
release/2.0
release/2.0-alpha
release/2.0-beta
release/2.0-rc
release/2.0-rc1
release/2.1
release/2.2
release/2.3
release/2.3-fc-ernie-fix
release/2.4
revert-24981-add_device_attr_for_regulization
revert-26856-strategy_example2
revert-27520-disable_pr
revert-31068-fix_conv3d_windows
revert-31562-mean
revert-32290-develop-hardlabel
revert-33037-forci
revert-33475-fix_cifar_label_dimension
revert-33630-bug-fix
revert-34159-add_npu_bce_logical_dev
revert-34406-add_copy_from_tensor
revert-34910-spinlocks_for_allocator
revert-35069-revert-34910-spinlocks_for_allocator
revert-36057-dev/read_flags_in_ut
revert-36201-refine_fast_threaded_ssa_graph_executor
revert-36985-add_license
revert-37318-refactor_dygraph_to_eager
revert-37926-eager_coreops_500
revert-37956-revert-37727-pylayer_support_tuple
revert-38100-mingdong
revert-38301-allocation_rearrange_pr
revert-38703-numpy_bf16_package_reupload
revert-38732-remove_useless_header_in_elementwise_mul_grad
revert-38959-Reduce_Grad
revert-39143-adjust_empty
revert-39227-move_trace_op_to_pten
revert-39268-dev/remove_concat_fluid_kernel
revert-40170-support_partial_grad
revert-41056-revert-40727-move_some_activaion_to_phi
revert-41065-revert-40993-mv_ele_floordiv_pow
revert-41068-revert-40790-phi_new
revert-41944-smaller_inference_api_test
revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator
revert-43155-fix_ut_tempfile
revert-43882-revert-41944-smaller_inference_api_test
revert-45808-phi/simplify_size_op
revert-46827-deform_comment
rocm_dev_0217
support_weight_transpose
test_benchmark_ci
test_feature_precision_test_c
test_model_benchmark
test_model_benchmark_ci
zhiqiu-patch-1
v2.4.0-rc0
v2.3.2
v2.3.1
v2.3.0
v2.3.0-rc0
v2.2.2
v2.2.1
v2.2.0
v2.2.0-rc0
v2.2.0-bak0
v2.1.3
v2.1.2
v2.1.1
v2.1.0
v2.1.0-rc0
v2.0.2
v2.0.1
v2.0.0
v2.0.0-rc1
v2.0.0-rc0
v2.0.0-beta0
v2.0.0-alpha0
v1.8.5
v1.8.4
v1.8.3
v1.8.2
v1.8.1
v1.8.0
v1.7.2
v1.7.1
v1.7.0
v1.6.3
v1.6.2
v1.6.1
v1.6.0
v1.6.0-rc0
无相关合并请求
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
122 addition
and
9 deletion
+122
-9
paddle/fluid/op_use_default_grad_op_maker.spec
paddle/fluid/op_use_default_grad_op_maker.spec
+0
-1
paddle/fluid/operators/hierarchical_sigmoid_op.cc
paddle/fluid/operators/hierarchical_sigmoid_op.cc
+42
-7
python/paddle/fluid/backward.py
python/paddle/fluid/backward.py
+23
-1
python/paddle/fluid/tests/unittests/test_backward_find_no_grad_vars.py
.../fluid/tests/unittests/test_backward_find_no_grad_vars.py
+57
-0
未找到文件。
paddle/fluid/op_use_default_grad_op_maker.spec
浏览文件 @
e0d8c6ac
...
...
@@ -15,7 +15,6 @@ fusion_seqexpand_concat_fc
fusion_seqpool_concat
fusion_squared_mat_sub
gru
hierarchical_sigmoid
lrn
lstm_unit
max_pool2d_with_index
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/hierarchical_sigmoid_op.cc
浏览文件 @
e0d8c6ac
...
...
@@ -86,6 +86,10 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
}
};
/*
* Inputs: X, W, Label, PathTable, PathCode, Bias
* Outputs: Out, PreOut, W_out
*/
template
<
typename
AttrType
>
class
HierarchicalSigmoidOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
...
...
@@ -162,6 +166,37 @@ Hierarchical Probabilistic Neural Network Language Model."
}
};
/*
* Inputs: X, W, Label, PathTable, PathCode, PreOut, Out@GRAD
* Outputs: X@GRAD, W@GRAD, Bias@GRAD
*/
class
HierarchicalSigmoidGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
op
=
new
framework
::
OpDesc
();
op
->
SetType
(
this
->
ForwardOpType
()
+
"_grad"
);
// Inputs: X, W, Label, PathTable, PathCode, PreOut, Out@GRAD
op
->
SetInput
(
"X"
,
Input
(
"X"
));
op
->
SetInput
(
"W"
,
Input
(
"W"
));
op
->
SetInput
(
"Bias"
,
Input
(
"Bias"
));
op
->
SetInput
(
"Label"
,
Input
(
"Label"
));
op
->
SetInput
(
"PathTable"
,
Input
(
"PathTable"
));
op
->
SetInput
(
"PathCode"
,
Input
(
"PathCode"
));
op
->
SetInput
(
"PreOut"
,
Output
(
"PreOut"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
// Outputs: X@GRAD, W@GRAD, Bias@GRAD
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"W"
),
InputGrad
(
"W"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Bias"
),
InputGrad
(
"Bias"
));
op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
};
class
HierarchicalSigmoidGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -209,17 +244,17 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
auto
attr
=
ctx
->
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
if
(
is_sparse
)
{
VLOG
(
3
0
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to SelectedRows"
;
ctx
->
SetType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
VLOG
(
3
0
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"W"
)
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
w_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
if
(
hasBias
)
{
VLOG
(
3
0
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"Bias"
)
<<
" is set to LoDTensor"
;
VLOG
(
3
)
<<
"hierarchical_sigmoid_grad op "
<<
framework
::
GradVarName
(
"Bias"
)
<<
" is set to LoDTensor"
;
ctx
->
SetType
(
bias_grad_var_name
,
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
ctx
->
SetDataType
(
w_grad_var_name
,
ctx
->
GetDataType
(
ctx
->
Input
(
"W"
)[
0
]));
...
...
@@ -232,7 +267,7 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
hierarchical_sigmoid
,
ops
::
HierarchicalSigmoidOp
,
ops
::
HierarchicalSigmoidOpMaker
<
int
>
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
ops
::
HierarchicalSigmoidGradMaker
);
REGISTER_OPERATOR
(
hierarchical_sigmoid_grad
,
ops
::
HierarchicalSigmoidGradOp
,
ops
::
HierarchicalSigmoidGradOpGradVarTypeInference
);
REGISTER_OP_CPU_KERNEL
(
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/backward.py
浏览文件 @
e0d8c6ac
...
...
@@ -552,7 +552,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
block_no_grad_set
=
set
(
map
(
_strip_grad_suffix_
,
no_grad_dict
[
0
]))
op_path
=
_find_op_path_
(
root_block
,
[
loss
],
[],
block_no_grad_set
)
no_grad_vars
=
_find_no_grad_vars
(
root_block
,
op_path
,
[
loss
],
block_no_grad_set
)
block_no_grad_set
.
update
(
no_grad_vars
)
no_grad_dict
[
0
].
update
(
list
(
map
(
_append_grad_suffix_
,
block_no_grad_set
)))
input_grad_names_set
=
None
...
...
@@ -630,6 +632,26 @@ def _as_list(x):
return
list
(
x
)
if
isinstance
(
x
,
collections
.
Sequence
)
else
[
x
]
def
_find_no_grad_vars
(
block
,
op_path
,
targets
,
no_grad_set
):
"""
Find the vars which is not used in the program, and
those var belong to no_grad_var.
"""
output_names
=
set
([
out
.
name
for
out
in
targets
])
no_grad_var
=
[]
for
i
,
op
in
reversed
(
list
(
enumerate
(
op_path
))):
# If the op has sub_block, it is too complicated to find the correct no_grad_var.
if
not
op
.
has_attr
(
"sub_block"
):
for
out_var
in
op
.
desc
.
output_arg_names
():
if
out_var
not
in
output_names
and
out_var
not
in
op
.
desc
.
input_arg_names
(
)
and
not
block
.
vars
[
out_var
].
stop_gradient
:
no_grad_var
.
append
(
out_var
)
for
name
in
op
.
desc
.
input_arg_names
():
if
name
not
in
no_grad_set
:
output_names
.
add
(
name
)
return
set
(
no_grad_var
)
def
_find_op_path_
(
block
,
outputs
,
inputs
,
no_grad_set
):
"""
no_grad_set will also be changed
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_backward_find_no_grad_vars.py
0 → 100644
浏览文件 @
e0d8c6ac
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
from
simple_nets
import
init_data
def
simple_net1
():
x
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
feature
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
20
,
act
=
None
)
part1
,
part2
=
fluid
.
layers
.
split
(
feature
,
num_or_sections
=
[
10
,
10
],
dim
=
1
)
# Note that: part2 is not used.
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
part1
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
loss
)
return
loss
class
TestBackward
(
unittest
.
TestCase
):
def
check_backward
(
self
,
model
):
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
batch_size
=
2
with
fluid
.
program_guard
(
main
,
startup
):
loss
=
model
()
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.1
)
optimizer
.
minimize
(
loss
)
exe
.
run
(
fluid
.
default_startup_program
())
img
,
label
=
init_data
(
batch_size
,
img_shape
=
[
784
],
label_range
=
9
)
exe
.
run
(
feed
=
{
'image'
:
img
,
'label'
:
label
})
def
test_backward
(
self
):
self
.
check_backward
(
simple_net1
)
if
__name__
==
'__main__'
:
unittest
.
main
()
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部