Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4e62af80
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4e62af80
编写于
3年前
作者:
C
cc
提交者:
GitHub
3年前
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add FP16 PRelu (#35532)
上级
afd1b372
develop
1.8.5
2.4.1
Ligoml-patch-1
ZHUI-patch-1
add_kylinv10
add_some_yaml_config
bugfix-eval-frame-leakgae
cherry-pick-fix-customOP-random-fail
cherry_undefined_var
cp_2.4_fix_numpy
delete_disable_iterable_dataset_unittest
delete_fix_undefined_var
delete_revert-36057-dev/read_flags_in_ut
dingjiaweiww-patch-1
disable_iterable_dataset_unittest
dy2static
enable_eager_model_test
final_state_gen_python_c
final_state_intermediate
fix-numpy-issue
fix-run-program-grad-node-mem
fix_check
fix_concat_slice
fix_custom_device_copy_sync
fix_dlpack_for
fix_newexe_gc
fix_op_flops
fix_rnn_docs
fix_tensor_type
fix_undefined_var
fix_var_stop_gradient_error
hack_event
incuabte/new_frl
incubate/frl_train_eval
incubate/infrt
incubate/new_frl
incubate/new_frl_rc
incubate/stride
inplace_addto
layer_norm
make_flag_adding_easier
matmul_double_grad
move_embedding_to_phi
move_histogram_to_pten
move_sgd_to_phi
move_slice_to_pten
move_temporal_shift_to_phi
move_yolo_box_to_phi
npu_fix_alloc
operator_opt
pass-compile-eval-frame
preln_ernie
prv-md-even-more
prv-onednn-2.5
prv-reshape-mkldnn-ut2
pten_tensor_refactor
release-deleted/2.5
release-rc/2.5
release/2.2
release/2.3
release/2.3-fc-ernie-fix
release/2.4
release/2.5
release/llm_2.5
revert-36057-dev/read_flags_in_ut
revert-36201-refine_fast_threaded_ssa_graph_executor
revert-36985-add_license
revert-37318-refactor_dygraph_to_eager
revert-37926-eager_coreops_500
revert-37956-revert-37727-pylayer_support_tuple
revert-38100-mingdong
revert-38301-allocation_rearrange_pr
revert-38703-numpy_bf16_package_reupload
revert-38732-remove_useless_header_in_elementwise_mul_grad
revert-38959-Reduce_Grad
revert-39143-adjust_empty
revert-39227-move_trace_op_to_pten
revert-39268-dev/remove_concat_fluid_kernel
revert-40170-support_partial_grad
revert-41056-revert-40727-move_some_activaion_to_phi
revert-41065-revert-40993-mv_ele_floordiv_pow
revert-41068-revert-40790-phi_new
revert-41944-smaller_inference_api_test
revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator
revert-43155-fix_ut_tempfile
revert-43882-revert-41944-smaller_inference_api_test
revert-45808-phi/simplify_size_op
revert-46827-deform_comment
revert-47325-remove_cudnn_hardcode
revert-47645-add_npu_storage_dims
revert-48815-set_free_when_no_cache_hit_default_value_true
revert-49499-test_ninja_on_ci
revert-49654-prim_api_gen
revert-49673-modify_get_single_cov
revert-49763-fix_static_composite_gen
revert-50158-fix_found_inf_bug_for_custom_optimizer
revert-50188-refine_optimizer_create_accumulators
revert-50335-fix_optminizer_set_auxiliary_var_bug
revert-51676-flag_delete
revert-51850-fix_softmaxce_dev
revert-52175-dev_peak_memory
revert-52186-deve
revert-52523-test_py38
revert-52912-develop
revert-53248-set_cmake_policy
revert-54029-fix_windows_compile_bug
revert-54068-support_translating_op_attribute
revert-54214-modify_cmake_dependencies
revert-54370-offline_pslib
revert-54391-fix_cmake_md5error
revert-54411-fix_cpp17_compile
revert-54466-offline_pslib
revert-54480-cmake-rocksdb
revert-55568-fix_BF16_bug1
revert-56328-new_ir_support_vector_type_place_transfer
revert-56366-fix_openssl_bug
revert-56545-revert-56366-fix_openssl_bug
revert-56620-fix_new_ir_ocr_bug
revert-56925-check_inputs_grad_semantic
revert-57005-refine_stride_flag
sd_conv_linear_autocast
semi-auto/rule-base
support-0D-sort
support_weight_transpose
test_for_Filtetfiles
zhiqiu-patch-1
v2.5.1
v2.5.0
v2.5.0-rc1
v2.5.0-rc0
v2.4.2
v2.4.1
v2.4.0
v2.4.0-rc0
v2.3.2
v2.3.1
v2.3.0
v2.3.0-rc0
v2.2.2
v2.2.1
v2.2.0
v2.2.0-rc0
v2.2.0-bak0
无相关合并请求
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
66 addition
and
10 deletion
+66
-10
paddle/fluid/operators/math/prelu.cu
paddle/fluid/operators/math/prelu.cu
+9
-3
paddle/fluid/operators/prelu_op.cu
paddle/fluid/operators/prelu_op.cu
+13
-6
python/paddle/fluid/tests/unittests/test_prelu_op.py
python/paddle/fluid/tests/unittests/test_prelu_op.py
+44
-1
未找到文件。
paddle/fluid/operators/math/prelu.cu
浏览文件 @
4e62af80
...
...
@@ -33,7 +33,8 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
size_t
channel_index
=
temp
%
channel_num
;
T
scale
=
alpha
[
channel_index
];
T
x
=
input
[
index
];
output
[
index
]
=
(
x
>
0
)
?
x
:
scale
*
x
;
T
zero
=
static_cast
<
T
>
(
0
);
output
[
index
]
=
(
x
>
zero
)
?
x
:
scale
*
x
;
}
}
...
...
@@ -45,7 +46,8 @@ __global__ void PReluElementWiseKernel(const T *input, const T *alpha,
size_t
element_index
=
index
%
spatial_size
;
T
scale
=
alpha
[
element_index
];
T
x
=
input
[
index
];
output
[
index
]
=
(
x
>
0
)
?
x
:
scale
*
x
;
T
zero
=
static_cast
<
T
>
(
0
);
output
[
index
]
=
(
x
>
zero
)
?
x
:
scale
*
x
;
}
}
...
...
@@ -55,7 +57,8 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
T
scale
=
alpha
[
0
];
CUDA_KERNEL_LOOP
(
index
,
numel
)
{
T
x
=
input
[
index
];
output
[
index
]
=
(
x
>
0
)
?
x
:
scale
*
x
;
T
zero
=
static_cast
<
T
>
(
0
);
output
[
index
]
=
(
x
>
zero
)
?
x
:
scale
*
x
;
}
}
...
...
@@ -88,12 +91,15 @@ void PreluScalarDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
}
template
class
PreluChannelWiseDirectCUDAFunctor
<
float
>;
template
class
PreluChannelWiseDirectCUDAFunctor
<
paddle
::
platform
::
float16
>;
template
class
PreluChannelWiseDirectCUDAFunctor
<
double
>;
template
class
PreluElementWiseDirectCUDAFunctor
<
float
>;
template
class
PreluElementWiseDirectCUDAFunctor
<
paddle
::
platform
::
float16
>;
template
class
PreluElementWiseDirectCUDAFunctor
<
double
>;
template
class
PreluScalarDirectCUDAFunctor
<
float
>;
template
class
PreluScalarDirectCUDAFunctor
<
paddle
::
platform
::
float16
>;
template
class
PreluScalarDirectCUDAFunctor
<
double
>;
}
// namespace math
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/prelu_op.cu
浏览文件 @
4e62af80
...
...
@@ -87,8 +87,9 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
}
T
x
=
x_ptr
[
index
];
T
dy
=
dy_ptr
[
index
];
if
(
dx_ptr
!=
nullptr
)
dx_ptr
[
index
]
=
(
x
>
0
)
?
dy
:
scale
*
dy
;
if
(
dalpha_ptr
!=
nullptr
)
dalpha_ptr
[
index
]
=
(
x
>
0
)
?
0
:
x
*
dy
;
T
zero
=
static_cast
<
T
>
(
0
);
if
(
dx_ptr
!=
nullptr
)
dx_ptr
[
index
]
=
(
x
>
zero
)
?
dy
:
scale
*
dy
;
if
(
dalpha_ptr
!=
nullptr
)
dalpha_ptr
[
index
]
=
(
x
>
zero
)
?
zero
:
x
*
dy
;
}
}
...
...
@@ -112,9 +113,11 @@ class PreluOpGradFunctor {
}
};
template
<
typename
T
>
struct
IdentityFunctor
{
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
x
;
}
template
<
typename
T
>
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
x
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -174,9 +177,9 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
reduce_dims
.
push_back
(
i
);
}
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
<
T
>
>
(
TensorReduce
<
T
,
T
,
cub
::
Sum
,
IdentityFunctor
>
(
dalpha_tmp
,
dalpha
,
reduce_dims
,
static_cast
<
T
>
(
0
),
cub
::
Sum
(),
IdentityFunctor
<
T
>
(),
stream
);
IdentityFunctor
(),
stream
);
}
};
...
...
@@ -184,10 +187,14 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
prelu
,
ops
::
CUDAPReluKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
CUDAPReluKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
CUDAPReluKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
prelu_grad
,
ops
::
CUDAPReluGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
CUDAPReluGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
CUDAPReluGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_prelu_op.py
浏览文件 @
4e62af80
...
...
@@ -153,11 +153,12 @@ class TestNNPReluAPI(unittest.TestCase):
class
PReluTest
(
OpTest
):
def
setUp
(
self
):
self
.
init_dtype
()
self
.
init_input_shape
()
self
.
init_attr
()
self
.
op_type
=
"prelu"
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
x_shape
)
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
x_shape
)
.
astype
(
self
.
dtype
)
# Since zero point in prelu is not differentiable, avoid randomize
# zero.
x_np
[
np
.
abs
(
x_np
)
<
0.005
]
=
0.02
...
...
@@ -168,6 +169,7 @@ class PReluTest(OpTest):
alpha_np
=
np
.
random
.
uniform
(
-
1
,
-
0.5
,
[
1
,
self
.
x_shape
[
1
],
1
,
1
])
else
:
alpha_np
=
np
.
random
.
uniform
(
-
1
,
-
0.5
,
[
1
]
+
self
.
x_shape
[
1
:])
alpha_np
=
alpha_np
.
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
x_np
,
'Alpha'
:
alpha_np
}
...
...
@@ -184,6 +186,9 @@ class PReluTest(OpTest):
assert
out_np
is
not
self
.
inputs
[
'X'
]
self
.
outputs
=
{
'Out'
:
out_np
}
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float64
def
init_input_shape
(
self
):
self
.
x_shape
=
[
2
,
100
,
3
,
4
]
...
...
@@ -270,6 +275,44 @@ class TestModeElementRank6(PReluTest):
self
.
attrs
=
{
'mode'
:
"element"
}
def
create_test_fp16_class
(
parent
,
check_grad
=
True
,
atol
=
1e-3
,
max_relative_error
=
0.05
):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestPReluFp16Case
(
parent
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
atol
)
def
test_check_grad
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
)
and
check_grad
:
self
.
check_grad_with_place
(
place
,
[
'X'
,
'Alpha'
],
'Out'
,
max_relative_error
=
max_relative_error
)
cls_name
=
"{0}_{1}"
.
format
(
parent
.
__name__
,
"Fp16Op"
)
TestPReluFp16Case
.
__name__
=
cls_name
globals
()[
cls_name
]
=
TestPReluFp16Case
create_test_fp16_class
(
TestModeElt
)
create_test_fp16_class
(
TestModeAllRank3
)
create_test_fp16_class
(
TestModeAllRank6
)
create_test_fp16_class
(
TestModeChannelRank3
)
create_test_fp16_class
(
TestModeChannelRank6
)
create_test_fp16_class
(
TestModeElementRank3
)
create_test_fp16_class
(
TestModeElementRank6
)
def
prelu_t
(
x
,
mode
,
param_attr
=
None
,
name
=
None
):
helper
=
fluid
.
layer_helper
.
LayerHelper
(
'prelu'
,
**
locals
())
alpha_shape
=
[
1
,
x
.
shape
[
1
],
1
,
1
]
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部