Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
538c8489
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
538c8489
编写于
5年前
作者:
Z
Zhang Ting
提交者:
Tao Luo
5年前
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add decorator skip_check_grad_ci (#21836)
上级
bf9c5de7
develop
2.0.1-rocm-post
Ligoml-patch-1
OliverLPH-patch-1
OliverLPH-patch-2
PaddlePM-patch-1
PaddlePM-patch-2
ZHUI-patch-1
add_default_att
add_model_benchmark_ci
add_some_yaml_config
addfile
all_new_design_exec
ascendrc
ascendrelease
cherry_undefined_var
compile_windows
cp_2.4_fix_numpy
delete_2.0.1-rocm-post
delete_add_default_att
delete_all_new_design_exec
delete_ascendrc
delete_compile_windows
delete_delete_addfile
delete_disable_iterable_dataset_unittest
delete_fix_dataloader_memory_leak
delete_fix_imperative_dygraph_error
delete_fix_retry_ci
delete_fix_undefined_var
delete_improve_sccache
delete_paralleltest
delete_prv-disable-more-cache
delete_revert-31068-fix_conv3d_windows
delete_revert-31562-mean
delete_revert-33630-bug-fix
delete_revert-34159-add_npu_bce_logical_dev
delete_revert-34910-spinlocks_for_allocator
delete_revert-35069-revert-34910-spinlocks_for_allocator
delete_revert-36057-dev/read_flags_in_ut
dingjiaweiww-patch-1
disable_iterable_dataset_unittest
dy2static
enable_eager_model_test
final_state_gen_python_c
final_state_intermediate
fix-numpy-issue
fix_concat_slice
fix_dataloader_memory_leak
fix_dlpack_for
fix_imperative_dygraph_error
fix_npu_ci
fix_op_flops
fix_retry_ci
fix_rnn_docs
fix_tensor_type
fix_undefined_var
fix_var_stop_gradient_error
fixiscan
fixiscan1
fixiscan2
fixiscan3
github/fork/123malin/netifaces
github/fork/123malin/tdm_abacus
github/fork/AshburnLee/dev_unique
github/fork/ForFishes/fix_memory_matmul
github/fork/ForFishes/rm_fluid
github/fork/LielinJiang/move-2.0-api
github/fork/LielinJiang/visual-dl-cb
github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api
github/fork/LiuChiachi/fix-example-code-for-hapi-Model
github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model
github/fork/MrChengmo/fix_ps_profiler
github/fork/MrChengmo/update_ps_heter
github/fork/PWhiddy/patch-1
github/fork/Shixiaowei02/dev/save_load_upgrade
github/fork/TCChenlong/fix_hapi
github/fork/TCChenlong/fix_inden
github/fork/Thunderbrook/xpu_slice
github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var
github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_2
github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3
github/fork/XieYunshen/timeout_20S_ut
github/fork/ZeyuChen/remove-nltk
github/fork/arlesniak/arlesniak/selective__mkldnn_flags
github/fork/baiyfbupt/code_doc_mig
github/fork/chalsliu/set_timeout
github/fork/chen-zhiyu/develop
github/fork/chenwhql/ci/try_to_find_test_buffer_shared_memory_reuse_pass_error
github/fork/chenwhql/dygraph/remove_scale_loss_and_apply_collective_grads
github/fork/chenwhql/saveload/add_get_inference_program
github/fork/chenwhql/saveload/remove_save_load_config
github/fork/cryoco/pass-compatibility-trt
github/fork/danleifeng/isempty_api2.0
github/fork/frankwhzhang/api_transfer
github/fork/hbwx24/error_msg/cuda_kernel_error_msg
github/fork/heavengate/cherry_yolo_box
github/fork/heavengate/update_yolo_box
github/fork/iclementine/rnn_fix
github/fork/iducn/testestse
github/fork/jczaja/prv-25537-fix
github/fork/jeff41404/release/1.8
github/fork/jiweibo/api_2.0
github/fork/jiweibo/fix_lite_resnet50_test
github/fork/juncaipeng/fix_doc_1
github/fork/lfchener/sample_code
github/fork/littletomatodonkey/fix_reg_doc
github/fork/liym27/dy2stat_update_assign_to_rc20
github/fork/luotao1/profiler_ut
github/fork/mapingshuo/add_wait
github/fork/mapingshuo/doc_2.0
github/fork/mapingshuo/zero-0.5
github/fork/miraiwk/dev
github/fork/pangyoki/add-Categorical-class-branch
github/fork/pangyoki/add-multinomial-op-branch
github/fork/pangyoki/fix-test_distritbution-CI
github/fork/qjing666/doublegrad
github/fork/qjing666/fix_hdfs_download
github/fork/sandyhouse/add_gather_etc
github/fork/sandyhouse/add_send_recv_alltoall_etc
github/fork/sandyhouse/pipeline_exe_run
github/fork/seiriosPlus/feature/large_scale_kv_save_delta
github/fork/seiriosPlus/fix/paddle_errors_fix
github/fork/seiriosPlus/fix/paddle_op_errors
github/fork/shangzhizhou/fix_test_activation_op_random_bug
github/fork/smallv0221/yxp0924
github/fork/smallv0221/yxp0925
github/fork/swtkiwi/del-matplotlib
github/fork/tianshuo78520a/kunlun_test
github/fork/tianshuo78520a/update_dockerfile
github/fork/wanghaoshuang/bert_fuse
github/fork/wanghaoshuang/label_smooth
github/fork/wanghuancoder/develop_CUDASynchronize
github/fork/wanghuancoder/develop_Layer_doc
github/fork/wanghuancoder/develop_ParameterList_doc
github/fork/wanghuancoder/develop_Sequential_doc
github/fork/wanghuancoder/develop_bilinear_tensor_product
github/fork/wanghuancoder/develop_coverage_build_sh
github/fork/wanghuancoder/develop_in_dynamic_mode_doc
github/fork/wanghuancoder/develop_unique_name_doc
github/fork/wangxicoding/fleet_meta_combine
github/fork/wawltor/error_message_fix_5
github/fork/willthefrog/remove_l2_norm
github/fork/windstamp/momentum_op
github/fork/windstamp/mv_op_5
github/fork/windstamp/normal_api
github/fork/wojtuss/wojtuss/fusion_gru_quantization
github/fork/wojtuss/wojtuss/quantization-with-shift
github/fork/wzzju/fix_err_info
github/fork/wzzju/pure_fp16
github/fork/xiemoyuan/op_error_message
github/fork/xiemoyuan/optimize_error_message
github/fork/yaoxuefeng6/fix_doc
github/fork/yaoxuefeng6/mod_dataset_v2
github/fork/yongqiangma/lod
github/fork/ysh329/fix-clip-by-norm-error
github/fork/ysh329/fix-error-clip-by-value
github/fork/yukavio/error_info
github/fork/zhangting2020/conv_filter_grad
github/fork/zhangting2020/is_compile_with_cuda
github/fork/zhangting2020/place_doc
github/fork/zhangting2020/program
github/fork/zhhsplendid/fix_any
github/fork/zhhsplendid/refine_api2
github/fork/zhhsplendid/refine_api2_test
github/fork/zhhsplendid/refine_api_test_ptb_lm
github/fork/zhhsplendid/refine_api_test_resnet
github/fork/zhhsplendid/refine_api_test_simnet
github/fork/zhiqiu/dev/refine_initializer
github/fork/zhiqiu/dev/remove_inplace_argument
github/fork/zlsh80826/nvinfer_plugin_var_len_cuda11
improve_sccache
incubate/frl_train_eval
incubate/infrt
inplace_addto
layer_norm
make_flag_adding_easier
matmul_double_grad
move_embedding_to_phi
move_histogram_to_pten
move_sgd_to_phi
move_slice_to_pten
move_temporal_shift_to_phi
move_yolo_box_to_phi
npu_fix_alloc
numel
paralleltest
preln_ernie
prv-disable-more-cache
prv-md-even-more
prv-onednn-2.5
prv-reshape-mkldnn-ut2
pten_tensor_refactor
release/1.7
release/1.8
release/2.0
release/2.0-alpha
release/2.0-beta
release/2.0-rc
release/2.0-rc1
release/2.1
release/2.2
release/2.3
release/2.3-fc-ernie-fix
release/2.4
revert-24981-add_device_attr_for_regulization
revert-26856-strategy_example2
revert-27520-disable_pr
revert-31068-fix_conv3d_windows
revert-31562-mean
revert-32290-develop-hardlabel
revert-33037-forci
revert-33475-fix_cifar_label_dimension
revert-33630-bug-fix
revert-34159-add_npu_bce_logical_dev
revert-34406-add_copy_from_tensor
revert-34910-spinlocks_for_allocator
revert-35069-revert-34910-spinlocks_for_allocator
revert-36057-dev/read_flags_in_ut
revert-36201-refine_fast_threaded_ssa_graph_executor
revert-36985-add_license
revert-37318-refactor_dygraph_to_eager
revert-37926-eager_coreops_500
revert-37956-revert-37727-pylayer_support_tuple
revert-38100-mingdong
revert-38301-allocation_rearrange_pr
revert-38703-numpy_bf16_package_reupload
revert-38732-remove_useless_header_in_elementwise_mul_grad
revert-38959-Reduce_Grad
revert-39143-adjust_empty
revert-39227-move_trace_op_to_pten
revert-39268-dev/remove_concat_fluid_kernel
revert-40170-support_partial_grad
revert-41056-revert-40727-move_some_activaion_to_phi
revert-41065-revert-40993-mv_ele_floordiv_pow
revert-41068-revert-40790-phi_new
revert-41944-smaller_inference_api_test
revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator
revert-43155-fix_ut_tempfile
revert-43882-revert-41944-smaller_inference_api_test
revert-45808-phi/simplify_size_op
revert-46827-deform_comment
revert-47325-remove_cudnn_hardcode
revert-47645-add_npu_storage_dims
revert-48815-set_free_when_no_cache_hit_default_value_true
revert-49654-prim_api_gen
revert-49763-fix_static_composite_gen
rocm_dev_0217
support-0D-sort
support_weight_transpose
test_benchmark_ci
test_feature_precision_test_c
test_for_Filtetfiles
test_model_benchmark
test_model_benchmark_ci
zhiqiu-patch-1
v2.4.1
v2.4.0
v2.4.0-rc0
v2.3.2
v2.3.1
v2.3.0
v2.3.0-rc0
v2.2.2
v2.2.1
v2.2.0
v2.2.0-rc0
v2.2.0-bak0
v2.1.3
v2.1.2
v2.1.1
v2.1.0
v2.1.0-rc0
v2.0.2
v2.0.1
v2.0.0
v2.0.0-rc1
v2.0.0-rc0
v2.0.0-beta0
v2.0.0-alpha0
v1.8.5
v1.8.4
v1.8.3
v1.8.2
v1.8.1
v1.8.0
v1.7.2
v1.7.1
v1.7.0
无相关合并请求
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
59 addition
and
59 deletion
+59
-59
python/paddle/fluid/tests/unittests/op_test.py
python/paddle/fluid/tests/unittests/op_test.py
+25
-1
python/paddle/fluid/tests/unittests/test_dropout_op.py
python/paddle/fluid/tests/unittests/test_dropout_op.py
+7
-1
python/paddle/fluid/tests/unittests/test_executor_return_tensor_not_overwriting.py
.../unittests/test_executor_return_tensor_not_overwriting.py
+2
-1
python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py
...addle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py
+3
-1
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
+9
-11
python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
...n/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
+9
-11
python/paddle/fluid/tests/unittests/test_seq_pool.py
python/paddle/fluid/tests/unittests/test_seq_pool.py
+4
-2
python/paddle/fluid/tests/unittests/white_list/op_check_grad_white_list.py
...id/tests/unittests/white_list/op_check_grad_white_list.py
+0
-31
未找到文件。
python/paddle/fluid/tests/unittests/op_test.py
浏览文件 @
538c8489
...
...
@@ -146,6 +146,30 @@ def get_numeric_gradient(place,
return
gradient_flat
.
reshape
(
tensor_to_check
.
shape
())
def
skip_check_grad_ci
(
reason
=
None
):
"""Decorator to skip check_grad CI.
Check_grad is required for Op test cases. However, there are some special
cases that do not need to do check_grad. This decorator is used to skip the
check_grad of the above cases.
Note: the execution of unit test will not be skipped. It just avoids check_grad
checking in tearDownClass method by setting a `no_need_check_grad` flag.
Example:
@skip_check_grad_ci(reason="For inference, check_grad is not required.")
class TestInference(OpTest):
"""
if
not
isinstance
(
reason
,
str
):
raise
AssertionError
(
"The reason for skipping check_grad is required."
)
def
wrapper
(
cls
):
cls
.
no_need_check_grad
=
True
return
cls
return
wrapper
class
OpTest
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
...
...
@@ -182,7 +206,7 @@ class OpTest(unittest.TestCase):
+
OpTest
.
op_type
+
" Op."
)
# cases and ops do no need check_grad
if
cls
.
__name__
in
op_check_grad_white_list
.
NO_NEED_CHECK_GRAD_CASES
\
if
hasattr
(
cls
,
"no_need_check_grad"
)
\
or
cls
.
op_type
in
op_check_grad_white_list
.
EMPTY_GRAD_OP_LIST
:
return
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_dropout_op.py
浏览文件 @
538c8489
...
...
@@ -17,7 +17,7 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
from
op_test
import
OpTest
,
skip_check_grad_ci
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
...
...
@@ -61,6 +61,7 @@ class TestDropoutOp3(TestDropoutOp):
}
@
skip_check_grad_ci
(
reason
=
"For inference, check_grad is not required."
)
class
TestDropoutOp4
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
...
...
@@ -74,6 +75,7 @@ class TestDropoutOp4(OpTest):
self
.
check_output
()
@
skip_check_grad_ci
(
reason
=
"For inference, check_grad is not required."
)
class
TestDropoutOp5
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
...
...
@@ -119,6 +121,7 @@ class TestDropoutOp7(TestDropoutOp):
}
@
skip_check_grad_ci
(
reason
=
"For inference, check_grad is not required."
)
class
TestDropoutOp8
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
...
...
@@ -135,6 +138,7 @@ class TestDropoutOp8(OpTest):
self
.
check_output
()
@
skip_check_grad_ci
(
reason
=
"For inference, check_grad is not required."
)
class
TestDropoutOp9
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
...
...
@@ -174,6 +178,7 @@ class TestDropoutOpWithSeed(OpTest):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
op_support_gpu
(
"dropout"
),
"core is not compiled with CUDA or core is not support dropout"
)
@
skip_check_grad_ci
(
reason
=
"For inference, check_grad is not required."
)
class
TestFP16DropoutOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
...
...
@@ -201,6 +206,7 @@ class TestFP16DropoutOp(OpTest):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
op_support_gpu
(
"dropout"
),
"core is not compiled with CUDA or core is not support dropout"
)
@
skip_check_grad_ci
(
reason
=
"For inference, check_grad is not required."
)
class
TestFP16DropoutOp2
(
TestFP16DropoutOp
):
def
init_test_case
(
self
):
self
.
input_size
=
[
32
,
64
,
3
]
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_executor_return_tensor_not_overwriting.py
浏览文件 @
538c8489
...
...
@@ -17,9 +17,10 @@ import unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
import
paddle.fluid
as
fluid
from
op_test
import
OpTest
from
op_test
import
OpTest
,
skip_check_grad_ci
@
skip_check_grad_ci
(
reason
=
"Not op test but call the method of class OpTest."
)
class
TestExecutorReturnTensorNotOverwritingWithOptest
(
OpTest
):
def
setUp
(
self
):
pass
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py
浏览文件 @
538c8489
...
...
@@ -17,7 +17,7 @@ from __future__ import print_function
import
unittest
import
platform
import
numpy
as
np
from
op_test
import
OpTest
from
op_test
import
OpTest
,
skip_check_grad_ci
import
paddle.fluid.core
as
core
import
paddle.fluid
as
fluid
from
paddle.fluid.op
import
Operator
...
...
@@ -25,6 +25,8 @@ import paddle.compat as cpt
import
paddle.version
as
ver
@
skip_check_grad_ci
(
reason
=
"check_grad is called when ver.mkl() == ON"
"and 'Linux' in platform.platform()."
)
class
TestFusedEmbeddingSeqPoolOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fused_embedding_seq_pool"
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
浏览文件 @
538c8489
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
from
op_test
import
OpTest
,
skip_check_grad_ci
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
import
paddle.compat
as
cpt
...
...
@@ -56,6 +56,10 @@ class TestLookupTableOpWithTensorIds(OpTest):
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
@
skip_check_grad_ci
(
reason
=
"Since paddings are not trainable and fixed in forward,"
"the gradient of paddings makes no sense and we don't "
"test the gradient here."
)
class
TestLookupTableOpWithPadding
(
TestLookupTableOp
):
def
test_check_output
(
self
):
ids
=
np
.
squeeze
(
self
.
inputs
[
'Ids'
])
...
...
@@ -64,12 +68,11 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
self
.
attrs
=
{
'padding_idx'
:
int
(
padding_idx
)}
self
.
check_output
()
def
test_check_grad
(
self
):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
@
skip_check_grad_ci
(
reason
=
"Since paddings are not trainable and fixed in forward,"
"the gradient of paddings makes no sense and we don't "
"test the gradient here."
)
class
TestLookupTableOpWithTensorIdsAndPadding
(
TestLookupTableOpWithTensorIds
):
def
test_check_output
(
self
):
ids
=
self
.
inputs
[
'Ids'
]
...
...
@@ -79,11 +82,6 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
self
.
attrs
=
{
'padding_idx'
:
cpt
.
long_type
(
padding_idx
)}
self
.
check_output
()
def
test_check_grad
(
self
):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
class
TestLookupTableWIsSelectedRows
(
unittest
.
TestCase
):
def
prepare_ids
(
self
,
scope
,
place
):
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
浏览文件 @
538c8489
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
from
op_test
import
OpTest
,
skip_check_grad_ci
import
paddle.fluid.core
as
core
import
paddle.fluid
as
fluid
from
paddle.fluid.op
import
Operator
...
...
@@ -55,6 +55,10 @@ class TestLookupTableOpWithTensorIds(OpTest):
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
@
skip_check_grad_ci
(
reason
=
"Since paddings are not trainable and fixed in forward,"
"the gradient of paddings makes no sense and we don't "
"test the gradient here."
)
class
TestLookupTableOpWithPadding
(
TestLookupTableOp
):
def
test_check_output
(
self
):
ids
=
np
.
squeeze
(
self
.
inputs
[
'Ids'
])
...
...
@@ -63,12 +67,11 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
self
.
attrs
=
{
'padding_idx'
:
int
(
padding_idx
)}
self
.
check_output
()
def
test_check_grad
(
self
):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
@
skip_check_grad_ci
(
reason
=
"Since paddings are not trainable and fixed in forward,"
"the gradient of paddings makes no sense and we don't "
"test the gradient here."
)
class
TestLookupTableOpWithTensorIdsAndPadding
(
TestLookupTableOpWithTensorIds
):
def
test_check_output
(
self
):
ids
=
self
.
inputs
[
'Ids'
]
...
...
@@ -78,11 +81,6 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
self
.
attrs
=
{
'padding_idx'
:
cpt
.
long_type
(
padding_idx
)}
self
.
check_output
()
def
test_check_grad
(
self
):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
class
TestLookupTableWIsSelectedRows
(
unittest
.
TestCase
):
def
prepare_ids
(
self
,
scope
,
place
):
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_seq_pool.py
浏览文件 @
538c8489
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
from
op_test
import
OpTest
,
skip_check_grad_ci
from
test_reorder_lod_tensor
import
convert_to_offset
...
...
@@ -355,6 +355,8 @@ class TestSeqMaxPool2DLen0LoDLevel2(TestSeqMaxPool2D):
return
[[
1
,
0
,
2
,
2
],
[
0
,
3
,
0
,
10
,
0
]]
@
skip_check_grad_ci
(
reason
=
"Grad computation does not apply to Sequence MAX "
"Pool executed when is_test is true."
)
class
TestSeqMaxPool2DInference
(
TestSeqMaxPool2D
):
def
compute
(
self
,
x
,
offset
,
out
):
self
.
attrs
=
{
"pad_value"
:
1.0
,
'pooltype'
:
"MAX"
,
'is_test'
:
True
}
...
...
@@ -368,7 +370,7 @@ class TestSeqMaxPool2DInference(TestSeqMaxPool2D):
out
[
i
]
=
np
.
reshape
(
np
.
amax
(
sub_x
,
axis
=
0
),
(
3
,
11
))
def
test_check_grad
(
self
):
"""Grad computation does not apply to Sequence MAX
"""Grad computation does not apply to Sequence MAX
Pool executed when is_test is true """
return
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/white_list/op_check_grad_white_list.py
浏览文件 @
538c8489
...
...
@@ -53,34 +53,3 @@ EMPTY_GRAD_OP_LIST = [
'hash'
,
'less_than'
,
'not_equal'
,
'eye'
,
'chunk_eval'
,
'is_empty'
,
'proximal_gd'
,
'collect_fpn_proposals'
,
'unique_with_counts'
]
# Special cases do not need to check grad
NO_NEED_CHECK_GRAD_CASES
=
[
'TestLookupTableOpWithPadding'
,
'TestLookupTableOpWithTensorIdsAndPadding'
,
'TestLookupTableOpWithPadding'
,
'TestLookupTableOpWithTensorIdsAndPadding'
,
'TestSeqMaxPool2DInference'
,
'TestSeqMaxPool2DInferenceLen0'
,
'TestSeqMaxPool2DInferenceLen0LoDLevel2'
,
'TestDropoutOp4'
,
'TestDropoutOp5'
,
'TestDropoutOp8'
,
'TestDropoutOp9'
,
'TestFP16DropoutOp'
,
'TestFP16DropoutOp2'
,
'TestExpandOpBoolean'
,
'TestFusedEmbeddingSeqPoolOp'
,
'TestMKLDNNConcatOp'
,
'TestMKLDNNConcatOp'
,
'TestMKLDNNConcatOp3'
,
'TestElementwiseMulMKLDNNOp_Integrated_With_Convs'
,
'TestConv2dTransposeMKLDNNOp'
,
'TestMKLDNNFuseBias'
,
'TestMKLDNNWithPad'
,
'TestMKLDNNWithStride'
,
'TestMKLDNNWithAsymPad'
,
'TestMKLDNNWithSamePad'
,
'TestMKLDNNWithValidPad'
,
'TestMKLDNNWithValidPad_NHWC'
,
]
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部