PaddlePaddle / Paddle — commit 3f2a665a (unverified)
Authored by Guoxia Wang on Nov 30, 2021; committed via GitHub on Nov 30, 2021.
support data_format='NHWC' for prelu channel mode (#37019)
* support data_format='NHWC' for prelu channel mode
Parent: 0c82e3a0
Showing 16 changed files with 425 additions and 130 deletions (+425 −130).
paddle/fluid/inference/tensorrt/convert/prelu_op.cc  +8 −3
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu  +4 −2
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h  +15 −7
paddle/fluid/operators/math/prelu.cu  +26 −7
paddle/fluid/operators/math/prelu.h  +2 −1
paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc  +14 −5
paddle/fluid/operators/prelu_op.cc  +30 −6
paddle/fluid/operators/prelu_op.cu  +24 −9
paddle/fluid/operators/prelu_op.h  +46 −22
python/paddle/fluid/layers/nn.py  +24 −5
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_prelu_op.py  +11 −3
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py  +51 −32
python/paddle/fluid/tests/unittests/test_imperative_layers.py  +3 −2
python/paddle/fluid/tests/unittests/test_prelu_op.py  +130 −15
python/paddle/nn/functional/activation.py  +25 −7
python/paddle/nn/layer/activation.py  +12 −4
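For orientation before the per-file diffs: the user-visible effect of this commit is an optional data_format argument on the PReLU APIs so that channel mode also works with channel-last tensors. The sketch below is illustrative usage only (not part of the diff) and assumes a Paddle build that already contains this change.

    # Illustrative: channel-mode PReLU on NCHW and NHWC inputs.
    import paddle
    import paddle.nn.functional as F

    x_nchw = paddle.rand([2, 8, 32, 32])              # N, C, H, W
    x_nhwc = paddle.transpose(x_nchw, [0, 2, 3, 1])   # N, H, W, C
    w = paddle.full([8], 0.25)                        # one slope per channel

    y1 = F.prelu(x_nchw, w)                           # default data_format="NCHW"
    y2 = F.prelu(x_nhwc, w, data_format="NHWC")       # channel is the last axis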
paddle/fluid/inference/tensorrt/convert/prelu_op.cc

@@ -34,6 +34,11 @@ class PReluOpConverter : public OpConverter {
   auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
   // Get attrs
   std::string mode = BOOST_GET_CONST(std::string, op_desc.GetAttr("mode"));
+  std::string data_format = "NCHW";
+  if (op_desc.HasAttr("data_format")) {
+    data_format =
+        BOOST_GET_CONST(std::string, op_desc.GetAttr("data_format"));
+  }
   auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
   auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();

@@ -47,7 +52,7 @@ class PReluOpConverter : public OpConverter {
   nvinfer1::ILayer* layer = nullptr;
   if (engine_->with_dynamic_shape()) {
     plugin::PReluPluginDynamic* plugin = new plugin::PReluPluginDynamic(
-        alpha_data, alpha_tensor_temp->numel(), mode);
+        alpha_data, alpha_tensor_temp->numel(), mode, data_format);
     layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
   } else {
 #if IS_TRT_VERSION_GE(7000)

@@ -74,8 +79,8 @@ class PReluOpConverter : public OpConverter {
     layer = TRT_ENGINE_ADD_LAYER(engine_, ParametricReLU, *input,
                                  *alpha_layer_output);
 #else
     plugin::PReluPlugin* plugin =
         new plugin::PReluPlugin(
-            alpha_data, alpha_tensor_temp->numel(), mode);
+            alpha_data, alpha_tensor_temp->numel(), mode, data_format);
     layer = engine_->AddPlugin(&input, input_num, plugin);
 #endif
   }
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu

@@ -69,10 +69,11 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
   }
   if (mode_ == "channel") {
+    bool channel_last = data_format_ == "NHWC";
     operators::math::PreluChannelWiseDirectCUDAFunctor<float>
         prelu_channel_wise;
     prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
-                       input_dims.d[1], numel);
+                       input_dims.d[1], channel_last, numel);
   } else if (mode_ == "element") {
     operators::math::PreluElementWiseDirectCUDAFunctor<float>
         prelu_element_wise;

@@ -168,10 +169,11 @@ int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
   }
   if (mode_ == "channel") {
+    bool channel_last = data_format_ == "NHWC";
     operators::math::PreluChannelWiseDirectCUDAFunctor<float>
         prelu_channel_wise;
     prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
-                       input_dims.d[1], numel);
+                       input_dims.d[1], channel_last, numel);
   } else if (mode_ == "element") {
     operators::math::PreluElementWiseDirectCUDAFunctor<float>
         prelu_element_wise;
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h

@@ -32,11 +32,12 @@ class PReluPlugin : public PluginTensorRT {
   std::vector<float> weight_;
   float* p_gpu_weight_;
   std::string mode_;
+  std::string data_format_;

  public:
   size_t getSerializationSize() const TRT_NOEXCEPT override {
     return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
-           SerializedSize(weight_);
+           SerializedSize(data_format_.c_str()) + SerializedSize(weight_);
   }

   // TRT will call this func when we need to serialize the configuration of

@@ -46,11 +47,12 @@ class PReluPlugin : public PluginTensorRT {
     serializeBase(buffer);
     SerializeValue(&buffer, weight_);
     SerializeValue(&buffer, mode_.c_str());
+    SerializeValue(&buffer, data_format_.c_str());
   }

   PReluPlugin(const float* weight, const int weight_num,
-              std::string const& mode)
-      : mode_(mode) {
+              std::string const& mode, std::string const& data_format)
+      : mode_(mode), data_format_(data_format) {
     weight_.resize(weight_num);
     std::copy(weight, weight + weight_num, weight_.data());
   }

@@ -63,13 +65,17 @@ class PReluPlugin : public PluginTensorRT {
     const char* prelu_mode;
     DeserializeValue(&serialData, &serialLength, &prelu_mode);
     mode_ = std::string(prelu_mode);
+    const char* prelu_data_format;
+    DeserializeValue(&serialData, &serialLength, &prelu_data_format);
+    data_format_ = std::string(prelu_data_format);
   }

   ~PReluPlugin() {}
   int initialize() TRT_NOEXCEPT override;
   void terminate() TRT_NOEXCEPT override;

   PReluPlugin* clone() const TRT_NOEXCEPT override {
-    auto* ptr = new PReluPlugin(weight_.data(), weight_.size(), mode_);
+    auto* ptr =
+        new PReluPlugin(weight_.data(), weight_.size(), mode_, data_format_);
     ptr->p_gpu_weight_ = p_gpu_weight_;
     return ptr;
   }

@@ -108,8 +114,8 @@ REGISTER_TRT_PLUGIN_V2(PReluPluginCreator);
 class PReluPluginDynamic : public DynamicPluginTensorRT {
  public:
   PReluPluginDynamic(const float* weight, const int weight_num,
-                     std::string const& mode)
-      : mode_(mode) {
+                     std::string const& mode, std::string const& data_format)
+      : mode_(mode), data_format_(data_format) {
     weight_.resize(weight_num);
     std::copy(weight, weight + weight_num, weight_.data());
   }

@@ -117,7 +123,8 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
   PReluPluginDynamic(void const* serialData, size_t serialLength);
   ~PReluPluginDynamic() {}
   nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
-    auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_);
+    auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_,
+                                      data_format_);
     ptr->p_gpu_weight_ = p_gpu_weight_;
     return ptr;
   }

@@ -167,6 +174,7 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
   std::vector<float> weight_;
   float* p_gpu_weight_;
   std::string mode_;
+  std::string data_format_;
 };
 #endif
paddle/fluid/operators/math/prelu.cu

@@ -25,7 +25,7 @@ inline static int PADDLE_GET_BLOCKS(const int N) {
 }

 template <typename T>
-__global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
-                                       T *output, size_t channel_num,
-                                       size_t plane_size, size_t numel) {
+__global__ void PReluChannelFirstWiseKernel(const T *input, const T *alpha,
+                                            T *output, size_t channel_num,
+                                            size_t plane_size, size_t numel) {
   CUDA_KERNEL_LOOP(index, numel) {

@@ -38,6 +38,19 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
   }
 }

+template <typename T>
+__global__ void PReluChannelLastWiseKernel(const T *input, const T *alpha,
+                                           T *output, size_t channel_num,
+                                           size_t numel) {
+  CUDA_KERNEL_LOOP(index, numel) {
+    size_t channel_index = index % channel_num;
+    T scale = alpha[channel_index];
+    T x = input[index];
+    T zero = static_cast<T>(0);
+    output[index] = (x > zero) ? x : scale * x;
+  }
+}
+
 template <typename T>
 __global__ void PReluElementWiseKernel(const T *input, const T *alpha,
                                        T *output, size_t spatial_size,

@@ -65,10 +78,16 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
 template <typename T>
 void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
     gpuStream_t stream, const T *input, const T *alpha, T *output,
-    size_t batch_size, size_t channel, size_t numel) {
-  PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
-                           stream>>>(input, alpha, output, channel,
-                                     numel / batch_size / channel, numel);
+    size_t batch_size, size_t channel, bool channel_last, size_t numel) {
+  if (channel_last) {
+    PReluChannelLastWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
+                                 stream>>>(input, alpha, output, channel,
+                                           numel);
+  } else {
+    PReluChannelFirstWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
+                                  stream>>>(input, alpha, output, channel,
+                                            numel / batch_size / channel,
+                                            numel);
+  }
 }

 template <typename T>
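As a plain reference for what the two kernels compute, here is a NumPy sketch of the per-element channel index (illustrative only, not part of the diff): channel-first uses (index / plane_size) % C on the flattened tensor, while channel-last simply uses index % C.

    # NumPy sketch of channel-first vs channel-last PReLU indexing (illustrative).
    import numpy as np

    def prelu_channel(x_flat, alpha, channel_num, plane_size, channel_last):
        out = np.empty_like(x_flat)
        for index in range(x_flat.size):
            if channel_last:
                channel_index = index % channel_num                   # NHWC path
            else:
                channel_index = (index // plane_size) % channel_num   # NCHW path
            scale = alpha[channel_index]
            x = x_flat[index]
            out[index] = x if x > 0 else scale * x
        return out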
paddle/fluid/operators/math/prelu.h

@@ -31,7 +31,8 @@ template <typename T>
 class PreluChannelWiseDirectCUDAFunctor {
  public:
   void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output,
-                  size_t batch_size, size_t channel, size_t numel);
+                  size_t batch_size, size_t channel, bool channel_last,
+                  size_t numel);
 };

 template <typename T>
paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc

@@ -34,7 +34,7 @@ class PReluMKLDNNHandler
                      const dnnl::engine engine, platform::Place cpu_place,
                      const Tensor* x, const Tensor* weights,
                      const std::string& uniq_name, const std::string& mode,
-                     bool is_test = false)
+                     const std::string& data_format, bool is_test = false)
       : platform::MKLDNNHandlerT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
             dev_ctx, engine, cpu_place,
             platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),

@@ -49,9 +49,14 @@ class PReluMKLDNNHandler
     if (weights->dims().size() != x->dims().size()) {
       auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1);
       if (mode == "channel") {
-        new_weights_dims[1] =
-            *std::max_element(weights_dims.begin(), weights_dims.end());
+        if (data_format == "NHWC") {
+          new_weights_dims[x->dims().size() - 1] =
+              *std::max_element(weights_dims.begin(), weights_dims.end());
+        } else {
+          new_weights_dims[1] =
+              *std::max_element(weights_dims.begin(), weights_dims.end());
+        }
       }
       weights_dims = std::move(new_weights_dims);
     }
     auto weights_md = memory::desc(weights_dims, MKLDNNGetDataType<T>(),

@@ -110,9 +115,11 @@ class PReluMKLDNNKernel : public framework::OpKernel<T> {
     auto* out = ctx.Output<Tensor>("Out");
     const bool is_test = ctx.Attr<bool>("is_test");
     const auto mode = ctx.Attr<std::string>("mode");
+    const auto data_format = ctx.Attr<std::string>("data_format");

     PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
-                                  alpha, ctx.InputName("X"), mode, is_test);
+                                  alpha, ctx.InputName("X"), mode, data_format,
+                                  is_test);

     auto src_memory_p = handler.AcquireSrcMemory(x);
     auto weights_memory_p =

@@ -149,9 +156,11 @@ class PReluGradMKLDNNKernel : public framework::OpKernel<T> {
     auto* alpha = ctx.Input<Tensor>("Alpha");
     const bool is_test = ctx.Attr<bool>("is_test");
     const auto mode = ctx.Attr<std::string>("mode");
+    const auto data_format = ctx.Attr<std::string>("data_format");

     PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
-                                  alpha, framework::GradVarName("X"), mode);
+                                  alpha, framework::GradVarName("X"), mode,
+                                  data_format);

     auto src_memory_p = handler.AcquireSrcMemory(x);
     auto weights_memory_p =
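The handler change above reshapes the 1-D alpha so oneDNN can broadcast it against the input; a small sketch of the resulting weights_dims (illustrative, not part of the diff):

    # Sketch of the alpha broadcast shape chosen for channel mode (illustrative).
    def channel_alpha_dims(x_rank, channel, data_format):
        dims = [1] * x_rank
        if data_format == "NHWC":
            dims[-1] = channel   # channel sits on the last axis
        else:
            dims[1] = channel    # classic NCHW layout
        return dims

    print(channel_alpha_dims(4, 16, "NCHW"))  # [1, 16, 1, 1]
    print(channel_alpha_dims(4, 16, "NHWC"))  # [1, 1, 1, 16]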
paddle/fluid/operators/prelu_op.cc

@@ -38,12 +38,6 @@ class PReluOp : public framework::OperatorWithKernel {
                             "But recevied alpha's size: %d.",
                             product(ctx->GetInputDim("Alpha"))));
     } else if (mode == "channel") {
-      PADDLE_ENFORCE_EQ(product(ctx->GetInputDim("Alpha")), x_dim[1],
-                        platform::errors::InvalidArgument(
-                            "For mode 'channel', size of weight Alpha must be "
-                            "equal to the number of channels of input(x). But "
-                            "recevied alpha's size: %d, x_dim[1]: %d",
-                            product(ctx->GetInputDim("Alpha")), x_dim[1]));
       auto x_rank = x_dim.size();
       PADDLE_ENFORCE_GE(x_rank, 2,
                         platform::errors::InvalidArgument(

@@ -51,6 +45,33 @@ class PReluOp : public framework::OperatorWithKernel {
                             "equal or larger than 2. But recevied X's "
                             "rank: %d",
                             x_rank));
+      const std::string data_format_str =
+          ctx->Attrs().Get<std::string>("data_format");
+      PADDLE_ENFORCE_EQ(data_format_str == "NCHW" || data_format_str == "NHWC",
+                        true,
+                        platform::errors::InvalidArgument(
+                            "For mode 'channel', data_format must be one of "
+                            "NCHW and NHWC. But recevied data_format: %s",
+                            data_format_str));
+      if (data_format_str == "NCHW") {
+        PADDLE_ENFORCE_EQ(
+            product(ctx->GetInputDim("Alpha")) == x_dim[1], true,
+            platform::errors::InvalidArgument(
+                "For mode 'channel', size of weight Alpha must be "
+                "equal to the number of channels of input(x). But "
+                "recevied alpha's size: %d, x_dim[1]: %d",
+                product(ctx->GetInputDim("Alpha")), x_dim[1]));
+      } else {
+        PADDLE_ENFORCE_EQ(
+            product(ctx->GetInputDim("Alpha")) == x_dim[x_rank - 1], true,
+            platform::errors::InvalidArgument(
+                "For mode 'channel', size of weight Alpha must be "
+                "equal to the number of channels of input(x). But "
+                "recevied alpha's size: %d, x_dim[%d]: %d",
+                product(ctx->GetInputDim("Alpha")), x_rank - 1,
+                x_dim[x_rank - 1]));
+      }
     } else if (mode == "element") {
       auto alpha_dim = ctx->GetInputDim("Alpha");
       auto alpha_rank = alpha_dim.size();

@@ -134,6 +155,9 @@ There are modes:
 )DOC");
     AddAttr<std::string>("mode", "The mode for inputs to share weights.")
         .SetDefault("all");
+    AddAttr<std::string>("data_format",
+                         "Data format that specifies the layout of input")
+        .SetDefault("NCHW");
     AddAttr<bool>("use_mkldnn",
                   "(bool, default false) Only used in mkldnn kernel")
         .SetDefault(false)
paddle/fluid/operators/prelu_op.cu

@@ -42,17 +42,22 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
     const T* alpha_ptr = alpha->data<T>();
     auto& mode = context.Attr<std::string>("mode");
+    auto& data_format = context.Attr<std::string>("data_format");

     int numel = x->numel();
     auto dim = x->dims();
-    VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1]
-            << ", numel:" << numel;
+    auto x_rank = dim.size();
+    VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] << ", dim["
+            << x_rank - 1 << "]:" << dim[x_rank - 1] << ", numel:" << numel;

     if (mode == "channel") {
+      bool channel_last = data_format == "NHWC";
+      size_t channel = channel_last ? dim[x_rank - 1] : dim[1];
       math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
       prelu_channel_wise(context.cuda_device_context().stream(), x_ptr,
-                         alpha_ptr, o_ptr, dim[0], dim[1], numel);
+                         alpha_ptr, o_ptr, dim[0], channel, channel_last,
+                         numel);
     } else if (mode == "element") {
       math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
       prelu_element_wise(context.cuda_device_context().stream(), x_ptr,

@@ -65,7 +70,7 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
   }
 };

-enum PRELU_MODE { Element, Channel, Scalar };
+enum PRELU_MODE { Element, ChannelFirst, ChannelLast, Scalar };

 template <typename T>
 __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,

@@ -78,10 +83,13 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
     if (mode == Element) {
       size_t element_index = index % spatial_size;
       scale = alpha_ptr[element_index];
-    } else if (mode == Channel) {
+    } else if (mode == ChannelFirst) {
       size_t temp = index / plane_size;
       size_t channel_index = temp % channel_num;
       scale = alpha_ptr[channel_index];
+    } else if (mode == ChannelLast) {
+      size_t channel_index = index % channel_num;
+      scale = alpha_ptr[channel_index];
     } else {
       scale = alpha_ptr[0];
     }

@@ -105,11 +113,13 @@ class PreluOpGradFunctor {
     }
     size_t plane_size = numel / input_dims[0] / input_dims[1];
     size_t spatial_size = numel / input_dims[0];
+    size_t channel = mode == ChannelLast ? input_dims[input_dims.size() - 1]
+                                         : input_dims[1];

     PReluOpGradKernel<
         T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
-        x, alpha, dy, dx, dalpha, input_dims[1], plane_size, spatial_size,
+        x, alpha, dy, dx, dalpha, channel, plane_size, spatial_size,
         numel, mode);
   }
 };

@@ -140,9 +150,11 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
     if (!dx && !dalpha) return;

     auto& mode = context.Attr<std::string>("mode");
+    auto& data_format = context.Attr<std::string>("data_format");

     int numel = x->numel();
     auto dim = x->dims();
+    auto x_rank = dim.size();
     std::vector<int> input_shape = framework::vectorize<int>(dim);
     auto stream = context.cuda_device_context().stream();

@@ -157,10 +169,12 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
     }

     PRELU_MODE m;
+    bool channel_last = false;
     if (mode == "element") {
       m = Element;
     } else if (mode == "channel") {
-      m = Channel;
+      channel_last = data_format == "NHWC";
+      m = channel_last ? ChannelLast : ChannelFirst;
     } else {
       m = Scalar;
     }

@@ -172,7 +186,8 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
     std::vector<int> reduce_dims;
     for (size_t i = 0; i < dim.size(); i++) {
-      if (mode == "channel" && i == 1) continue;
+      if (mode == "channel" && !channel_last && i == 1) continue;
+      if (mode == "channel" && channel_last && i == dim.size() - 1) continue;
       if (mode == "element" && i != 0) continue;
       reduce_dims.push_back(i);
     }
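The reduce_dims change decides which axes of the elementwise dalpha are summed away; a sketch of the selection logic for the two layouts (illustrative only, mirroring the loop above):

    # Axes reduced when accumulating dalpha (illustrative mirror of the loop).
    def dalpha_reduce_dims(rank, mode, channel_last):
        dims = []
        for i in range(rank):
            if mode == "channel" and not channel_last and i == 1:
                continue                 # keep the NCHW channel axis
            if mode == "channel" and channel_last and i == rank - 1:
                continue                 # keep the NHWC channel axis
            if mode == "element" and i != 0:
                continue
            dims.append(i)
        return dims

    print(dalpha_reduce_dims(4, "channel", False))  # [0, 2, 3]
    print(dalpha_reduce_dims(4, "channel", True))   # [0, 1, 2]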
paddle/fluid/operators/prelu_op.h

@@ -33,12 +33,14 @@ class PReluKernel : public framework::OpKernel<T> {
     const T* alpha_ptr = alpha->data<T>();
     auto& mode = context.Attr<std::string>("mode");
+    auto& data_format = context.Attr<std::string>("data_format");

     int numel = x->numel();
     auto dim = x->dims();
     int index = 0;
     int i = 0;
     if (mode == "channel") {
+      if (data_format == "NCHW") {
       int temp = 1;
       for (int j = 2; j < dim.size(); j++) {
         temp *= dim[j];

@@ -47,6 +49,12 @@ class PReluKernel : public framework::OpKernel<T> {
         index = (i / temp) % dim[1];
         o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
       }
+      } else {
+        for (i = 0; i < numel; i++) {
+          index = i % dim[dim.size() - 1];
+          o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
+        }
+      }
     } else if (mode == "element") {
       int temp = 1;
       for (int j = 1; j < dim.size(); j++) {

@@ -77,6 +85,7 @@ class PReluGradKernel : public framework::OpKernel<T> {
     const T* x_ptr = x->data<T>();
     const T* dout_ptr = dout->data<T>();
     std::string mode = context.Attr<std::string>("mode");
+    auto& data_format = context.Attr<std::string>("data_format");
     int numel = x->numel();
     auto dim = x->dims();
     int index = 0;

@@ -84,6 +93,7 @@ class PReluGradKernel : public framework::OpKernel<T> {
     if (dx) {
       T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
       if (mode == "channel") {
+        if (data_format == "NCHW") {
         int temp = 1;
         for (int j = 2; j < dim.size(); j++) {
           temp *= dim[j];

@@ -93,6 +103,13 @@ class PReluGradKernel : public framework::OpKernel<T> {
           dx_ptr[i] =
               x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
         }
+        } else {
+          for (i = 0; i < numel; i++) {
+            index = i % dim[dim.size() - 1];
+            dx_ptr[i] =
+                x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
+          }
+        }
       } else if (mode == "element") {
         int temp = 1;
         for (int j = 1; j < dim.size(); j++) {

@@ -116,6 +133,7 @@ class PReluGradKernel : public framework::OpKernel<T> {
       memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel());
       if (mode == "channel") {
+        if (data_format == "NCHW") {
         int temp = 1;
         for (int j = 2; j < dim.size(); j++) {
           temp *= dim[j];

@@ -124,6 +142,12 @@ class PReluGradKernel : public framework::OpKernel<T> {
           index = (i / temp) % dim[1];
           dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
         }
+        } else {
+          for (i = 0; i < numel; i++) {
+            index = i % dim[dim.size() - 1];
+            dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
+          }
+        }
       } else if (mode == "element") {
         int temp = 1;
         for (int j = 1; j < dim.size(); j++) {
python/paddle/fluid/layers/nn.py

@@ -9791,7 +9791,7 @@ def swish(x, beta=1.0, name=None):

 @deprecated(since="2.0.0", update_to="paddle.static.nn.prelu")
-def prelu(x, mode, param_attr=None, name=None):
+def prelu(x, mode, param_attr=None, data_format="NCHW", name=None):
     r"""
     prelu activation.

@@ -9819,6 +9819,9 @@ def prelu(x, mode, param_attr=None, name=None):
         name (str, optional): Name for the operation (optional, default is None). \
             For more information, please refer to :ref:`api_guide_Name`.
+        data_format(str, optional): Data format that specifies the layout of input.
+            It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".

     Returns:
         Tensor: A tensor with the same shape and data type as x.

@@ -9839,17 +9842,32 @@ def prelu(x, mode, param_attr=None, name=None):
     helper = LayerHelper('prelu', **locals())
     if mode not in ['all', 'channel', 'element']:
         raise ValueError('mode should be one of all, channel, element.')
     alpha_shape = [1]
+    # NOTE(): The input of this API should be ``N,C,...`` format,
+    # which means x.shape[0] is batch_size and x.shape[0] is channel.
     if mode == 'channel':
+        true_data_format = [
+            'NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC'
+        ]
+        if data_format not in true_data_format:
+            raise ValueError(
+                "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
+                "'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format))
+
+        data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
+
         assert len(
             x.shape
         ) >= 2, "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'"
         #NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]).
         # To be consistent with Prelu, it is simplified.
         #NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version.
-        alpha_shape = [1, x.shape[1], 1, 1]
+        #NOTE(GuoxiaWang): support NHWC data format
+        if data_format == 'NHWC':
+            alpha_shape = [1, 1, 1, x.shape[1]]
+        else:
+            alpha_shape = [1, x.shape[1], 1, 1]
     elif mode == 'element':
         assert len(
             x.shape

@@ -9867,7 +9885,8 @@ def prelu(x, mode, param_attr=None, name=None):
         type="prelu",
         inputs={"X": x,
                 'Alpha': alpha},
-        attrs={"mode": mode},
+        attrs={"mode": mode,
+               "data_format": data_format},
         outputs={"Out": out})
     return out
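The new validation rejects layouts outside the listed set in channel mode. A short static-graph sketch of that behaviour (illustrative usage only, mirroring the error-path test added later in this commit):

    # Illustrative: an unknown layout string raises ValueError in channel mode.
    import paddle

    paddle.enable_static()
    x = paddle.static.data(name='x', shape=[2, 3, 4, 5], dtype='float32')
    try:
        y = paddle.static.nn.prelu(x, 'channel', data_format='N')
    except ValueError as e:
        print(e)  # data_format must be one of 'NC', 'NCL', 'NCHW', ...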
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_prelu_op.py

@@ -44,8 +44,12 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest):
                 if len(kwargs['in_shape']) <= 1:
                     # not valid case, just return 0
                     return np.zeros((1)).astype(np.float32)
-                return np.random.random(kwargs['in_shape'][1]).astype(
-                    np.float32)
+                if kwargs['data_format'] == 'NCHW':
+                    return np.random.random(kwargs['in_shape'][1]).astype(
+                        np.float32)
+                else:
+                    return np.random.random(kwargs['in_shape'][-1]).astype(
+                        np.float32)
             else:
                 if len(kwargs['in_shape']) <= 1:
                     # not valid case, just return 0

@@ -57,7 +61,10 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest):
             inputs={"X": ["input_data"],
                     "Alpha": ["alpha_weight"]},
             outputs={"Out": ["output_data"]},
-            attrs={"mode": kwargs['mode']})
+            attrs={
+                "mode": kwargs['mode'],
+                "data_format": kwargs['data_format']
+            })

         program_config = ProgramConfig(
             ops=[prelu_op],

@@ -82,6 +89,7 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest):
     @given(
         mode=st.sampled_from(['all', 'channel', 'element']),
+        data_format=st.sampled_from(['NCHW', 'NHWC']),
        in_shape=st.lists(
            st.integers(
                min_value=1, max_value=32), min_size=1, max_size=4))
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py

@@ -39,7 +39,8 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
         def generate_alpha(attrs: List[Dict[str, Any]], dim1, dim2, dim3):
             if attrs[0]["mode"] == "all":
                 return np.random.random(size=(1)).astype(np.float32)
-            elif attrs[0]["mode"] == "channel":
+            elif attrs[0]["mode"] == "channel" and attrs[0][
+                    "data_format"] == "NCHW":
                 shape = [1]
                 if dim1 != 0:
                     shape.append(dim1)

@@ -48,6 +49,16 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
                 if dim3 != 0:
                     shape.append(1)
                 return np.random.random(size=shape).astype(np.float32)
+            elif attrs[0]["mode"] == "channel" and attrs[0][
+                    "data_format"] == "NHWC":
+                shape = [1]
+                if dim1 != 0:
+                    shape.append(1)
+                if dim2 != 0:
+                    shape.append(1)
+                if dim3 != 0:
+                    shape.append(dim3)
+                return np.random.random(size=shape).astype(np.float32)
             elif attrs[0]["mode"] == "element":
                 shape = [1]
                 if dim1 != 0:

@@ -72,9 +83,15 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
                         continue
                     for mode in ["all", "channel", "element"]:
-                        if mode == "channel" and dim1 == 0:
-                            continue
-                        dics = [{"mode": mode}]
+                        for data_format in ['NCHW', 'NHWC']:
+                            if mode == "channel" and dim1 == 0 and \
+                                    data_format == "NCHW":
+                                continue
+                            if mode == "channel" and dim3 == 0 and \
+                                    data_format == "NHWC":
+                                continue
+                            dics = [{"mode": mode, "data_format": data_format}]
                             ops_config = [{
                                 "op_type": "prelu",
                                 "op_inputs": {

@@ -92,13 +109,15 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest):
                             ops=ops,
                             weights={
                                 "alpha_weight": TensorConfig(
-                                    data_gen=partial(generate_alpha, dics,
-                                                     dim1, dim2, dim3))
+                                    data_gen=partial(generate_alpha,
+                                                     dics, dim1, dim2, dim3))
                             },
                             inputs={
                                 "input_data": TensorConfig(
-                                    data_gen=partial(generate_input, batch,
-                                                     dim1, dim2, dim3)),
+                                    data_gen=partial(generate_input,
+                                                     batch, dim1, dim2, dim3)),
                             },
                             outputs=["output_data"])
python/paddle/fluid/tests/unittests/test_imperative_layers.py

@@ -41,10 +41,11 @@ class TestLayerPrint(unittest.TestCase):
         self.assertEqual(
             str(module), 'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)')

-        module = nn.PReLU(1, 0.25, name="PReLU")
+        module = nn.PReLU(1, 0.25, name="PReLU", data_format="NCHW")
         self.assertEqual(
             str(module),
-            'PReLU(num_parameters=1, init=0.25, dtype=float32, name=PReLU)')
+            'PReLU(num_parameters=1, data_format=NCHW, init=0.25, dtype=float32, name=PReLU)'
+        )

         module = nn.ReLU()
         self.assertEqual(str(module), 'ReLU()')
python/paddle/fluid/tests/unittests/test_prelu_op.py

@@ -163,10 +163,18 @@ class PReluTest(OpTest):
         # zero.
         x_np[np.abs(x_np) < 0.005] = 0.02

-        if self.attrs == {'mode': "all"}:
+        if self.attrs == {
+                'mode': "all", "data_format": "NCHW"
+        } or self.attrs == {'mode': "all", "data_format": "NHWC"}:
             alpha_np = np.random.uniform(-1, -0.5, (1))
-        elif self.attrs == {'mode': "channel"}:
+        elif self.attrs == {'mode': "channel", "data_format": "NCHW"}:
             alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1], 1, 1])
+        elif self.attrs == {'mode': "channel", "data_format": "NHWC"}:
+            alpha_np = np.random.uniform(-1, -0.5,
+                                         [1, 1, 1, self.x_shape[-1]])
         else:
             alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
         alpha_np = alpha_np.astype(self.dtype)

@@ -176,11 +184,14 @@ class PReluTest(OpTest):
         # NOTE(zhiqu): reshape inputs['Alpha'] from [1, 100, 1, 1] to [1, 100] + [1]*len(x.shape[2:])
         # since np operands could not be broadcast together with shapes (1,100,2,2,2,3) (1,100,1,1)
         reshaped_alpha = self.inputs['Alpha']
-        if self.attrs == {'mode': "channel"}:
+        if self.attrs == {'mode': "channel", "data_format": "NCHW"}:
             reshaped_alpha = np.reshape(
                 self.inputs['Alpha'],
                 [1, self.x_shape[1]] + [1] * len(self.x_shape[2:]))
+        elif self.attrs == {'mode': "channel", "data_format": "NHWC"}:
+            reshaped_alpha = np.reshape(
+                self.inputs['Alpha'],
+                [1] + [1] * len(self.x_shape[1:-1]) + [self.x_shape[-1]])
         out_np = np.maximum(self.inputs['X'], 0.)
         out_np = out_np + np.minimum(self.inputs['X'], 0.) * reshaped_alpha
         assert out_np is not self.inputs['X']

@@ -193,7 +204,7 @@ class PReluTest(OpTest):
         self.x_shape = [2, 100, 3, 4]

     def init_attr(self):
-        self.attrs = {'mode': "channel"}
+        self.attrs = {'mode': "channel", "data_format": "NCHW"}

     def test_check_output(self):
         self.check_output()

@@ -210,7 +221,18 @@ class TestModeAll(PReluTest):
         self.x_shape = [2, 3, 4, 5]

     def init_attr(self):
-        self.attrs = {'mode': "all"}
+        self.attrs = {'mode': "all", "data_format": "NCHW"}
+
+
+@skip_check_grad_ci(
+    reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
+)
+class TestModeAllNHWC(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [2, 3, 4, 50]
+
+    def init_attr(self):
+        self.attrs = {'mode': "all", "data_format": "NHWC"}


 class TestModeElt(PReluTest):

@@ -218,7 +240,15 @@ class TestModeElt(PReluTest):
         self.x_shape = [3, 2, 5, 10]

     def init_attr(self):
-        self.attrs = {'mode': "element"}
+        self.attrs = {'mode': "element", "data_format": "NCHW"}
+
+
+class TestModeEltNHWC(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [3, 2, 5, 10]
+
+    def init_attr(self):
+        self.attrs = {'mode': "element", "data_format": "NHWC"}


 @skip_check_grad_ci(

@@ -229,7 +259,18 @@ class TestModeAllRank3(PReluTest):
         self.x_shape = [1, 200, 3]

     def init_attr(self):
-        self.attrs = {'mode': "all"}
+        self.attrs = {'mode': "all", "data_format": "NCHW"}
+
+
+@skip_check_grad_ci(
+    reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
+)
+class TestModeAllRank3NHWC(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [1, 200, 3]
+
+    def init_attr(self):
+        self.attrs = {'mode': "all", "data_format": "NHWC"}


 @skip_check_grad_ci(

@@ -240,7 +281,18 @@ class TestModeAllRank6(PReluTest):
         self.x_shape = [1, 2, 3, 4, 5, 6]

     def init_attr(self):
-        self.attrs = {'mode': "all"}
+        self.attrs = {'mode': "all", "data_format": "NCHW"}
+
+
+@skip_check_grad_ci(
+    reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
+)
+class TestModeAllRank6NHWC(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [1, 2, 3, 4, 5, 6]
+
+    def init_attr(self):
+        self.attrs = {'mode': "all", "data_format": "NHWC"}


 class TestModeChannelRank3(PReluTest):

@@ -248,7 +300,15 @@ class TestModeChannelRank3(PReluTest):
         self.x_shape = [1, 200, 3]

     def init_attr(self):
-        self.attrs = {'mode': "channel"}
+        self.attrs = {'mode': "channel", "data_format": "NCHW"}
+
+
+class TestModeChannelRank3NHWC(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [1, 3, 100]
+
+    def init_attr(self):
+        self.attrs = {'mode': "channel", "data_format": "NHWC"}


 class TestModeChannelRank6(PReluTest):

@@ -256,7 +316,15 @@ class TestModeChannelRank6(PReluTest):
         self.x_shape = [1, 100, 2, 2, 2, 2]

     def init_attr(self):
-        self.attrs = {'mode': "channel"}
+        self.attrs = {'mode': "channel", "data_format": "NCHW"}
+
+
+class TestModeChannelRank6NHWC(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [1, 2, 2, 2, 2, 100]
+
+    def init_attr(self):
+        self.attrs = {'mode': "channel", "data_format": "NHWC"}


 class TestModeElementRank3(PReluTest):

@@ -264,7 +332,15 @@ class TestModeElementRank3(PReluTest):
         self.x_shape = [3, 10, 10]

     def init_attr(self):
-        self.attrs = {'mode': "element"}
+        self.attrs = {'mode': "element", "data_format": "NCHW"}
+
+
+class TestModeElementRank3NHWC(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [3, 10, 10]
+
+    def init_attr(self):
+        self.attrs = {'mode': "element", "data_format": "NHWC"}


 class TestModeElementRank6(PReluTest):

@@ -272,7 +348,15 @@ class TestModeElementRank6(PReluTest):
         self.x_shape = [3, 2, 2, 4, 5, 2]

     def init_attr(self):
-        self.attrs = {'mode': "element"}
+        self.attrs = {'mode': "element", "data_format": "NCHW"}
+
+
+class TestModeElementRank6NHWC(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [3, 2, 2, 4, 5, 2]
+
+    def init_attr(self):
+        self.attrs = {'mode': "element", "data_format": "NHWC"}


 def create_test_fp16_class(parent,

@@ -311,9 +395,16 @@ create_test_fp16_class(TestModeChannelRank3)
 create_test_fp16_class(TestModeChannelRank6)
 create_test_fp16_class(TestModeElementRank3)
 create_test_fp16_class(TestModeElementRank6)
+create_test_fp16_class(TestModeEltNHWC)
+create_test_fp16_class(TestModeAllRank3NHWC)
+create_test_fp16_class(TestModeAllRank6NHWC)
+create_test_fp16_class(TestModeChannelRank3NHWC)
+create_test_fp16_class(TestModeChannelRank6NHWC)
+create_test_fp16_class(TestModeElementRank3NHWC)
+create_test_fp16_class(TestModeElementRank6NHWC)


-def prelu_t(x, mode, param_attr=None, name=None):
+def prelu_t(x, mode, param_attr=None, name=None, data_format='NCHW'):
     helper = fluid.layer_helper.LayerHelper('prelu', **locals())
     alpha_shape = [1, x.shape[1], 1, 1]
     dtype = helper.input_dtype(input_param_name='x')

@@ -328,13 +419,19 @@ def prelu_t(x, mode, param_attr=None, name=None):
         type="prelu",
         inputs={"X": x,
                 'Alpha': alpha},
-        attrs={"mode": mode},
+        attrs={"mode": mode,
+               'data_format': data_format},
         outputs={"Out": out})
     return out


 # error message test if mode is not one of 'all', 'channel', 'element'
 class TestModeError(unittest.TestCase):
     def setUp(self):
         self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
         ) else paddle.CPUPlace()
         self.x_np = np.ones([1, 2, 3, 4]).astype('float32')

     def test_mode_error(self):
         main_program = Program()
         with fluid.program_guard(main_program, Program()):

@@ -344,6 +441,24 @@ class TestModeError(unittest.TestCase):
             except Exception as e:
                 assert (e.args[0].find('InvalidArgument') != -1)

+    def test_data_format_error1(self):
+        main_program = Program()
+        with fluid.program_guard(main_program, Program()):
+            x = fluid.data(name='x', shape=[2, 3, 4, 5])
+            try:
+                y = prelu_t(x, 'channel', data_format='N')
+            except Exception as e:
+                assert (e.args[0].find('InvalidArgument') != -1)
+
+    def test_data_format_error2(self):
+        main_program = Program()
+        with fluid.program_guard(main_program, Program()):
+            x = fluid.data(name='x', shape=[2, 3, 4, 5])
+            try:
+                y = paddle.static.nn.prelu(x, 'channel', data_format='N')
+            except ValueError as e:
+                pass
+

 if __name__ == "__main__":
     unittest.main()
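The reference output in the test above reshapes alpha so NumPy broadcasting matches the chosen layout; a condensed sketch of that reference computation (illustrative only):

    # Reference channel-mode PReLU used by the test, condensed (illustrative).
    import numpy as np

    def ref_prelu_channel(x, alpha, data_format):
        if data_format == "NCHW":
            a = alpha.reshape([1, x.shape[1]] + [1] * (x.ndim - 2))
        else:  # NHWC: put the channel on the last axis
            a = alpha.reshape([1] * (x.ndim - 1) + [x.shape[-1]])
        return np.maximum(x, 0.) + np.minimum(x, 0.) * a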
python/paddle/nn/functional/activation.py

@@ -442,7 +442,7 @@ def leaky_relu(x, negative_slope=0.01, name=None):
     return out


-def prelu(x, weight, name=None):
+def prelu(x, weight, data_format="NCHW", name=None):
     """
     prelu activation.

@@ -456,6 +456,8 @@ def prelu(x, weight, name=None):
             The weight shape is [1] or [in], where `in` is the input channel of ``x``.
         name (str, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`.
+        data_format(str, optional): Data format that specifies the layout of input.
+            It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".

     Returns:
         A Tensor with the same data type and shape as ``x`` .

@@ -490,19 +492,34 @@ def prelu(x, weight, name=None):
     assert len(weight.shape
                ) == 1, "The dim count of weight shape should be 1 in prelu()."

+    # NOTE(): The input of this API should be ``N,C,...`` format,
+    # which means x.shape[0] is batch_size and x.shape[0] is channel.
     mode = 'all'
     if weight.shape[0] > 1:
+        true_data_format = [
+            'NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC'
+        ]
+        if data_format not in true_data_format:
+            raise ValueError(
+                "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
+                "'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format))
+
+        data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
+
         assert len(
             x.shape
         ) > 1, "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]."
-        assert weight.shape[0] == x.shape[
-            1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
+
+        # NOTE(GuoxiaWang): support NHWC data format
+        if data_format == 'NHWC':
+            assert weight.shape[0] == x.shape[
+                -1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
+        else:
+            assert weight.shape[0] == x.shape[
+                1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
+
         mode = 'channel'

     if in_dygraph_mode():
-        return _C_ops.prelu(x, weight, 'mode', mode)
+        return _C_ops.prelu(x, weight, 'mode', mode, 'data_format', data_format)

     helper = LayerHelper('prelu', **locals())
     out = helper.create_variable_for_type_inference(x.dtype)

@@ -511,7 +528,8 @@ def prelu(x, weight, name=None):
         inputs={"X": x,
                 "Alpha": weight},
         outputs={"Out": out},
-        attrs={"mode": mode})
+        attrs={"mode": mode,
+               "data_format": data_format})
     return out
python/paddle/nn/layer/activation.py

@@ -376,6 +376,8 @@ class PReLU(Layer):
             Default is None. For more information, please refer to :ref:`api_paddle_ParamAttr`.
         name (str, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`.
+        data_format(str, optional): Data format that specifies the layout of input.
+            It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".

     Shape:
         - input: Tensor with any shape. Default dtype is float32.

@@ -406,13 +408,18 @@ class PReLU(Layer):
             #    [ 6. ,  7. ,  8. ,  9. ]]]]
     """

-    def __init__(self, num_parameters=1, init=0.25, weight_attr=None,
+    def __init__(self,
+                 num_parameters=1,
+                 init=0.25,
+                 weight_attr=None,
+                 data_format="NCHW",
                  name=None):
         super(PReLU, self).__init__()
         self._num_parameters = num_parameters
         self._init = init
         self._weight_attr = weight_attr
         self._name = name
+        self._data_format = data_format

         self._weight = self.create_parameter(
             attr=self._weight_attr,

@@ -422,12 +429,13 @@ class PReLU(Layer):
             default_initializer=Constant(self._init))

     def forward(self, x):
-        return F.prelu(x, self._weight)
+        return F.prelu(x, self._weight, data_format=self._data_format)

     def extra_repr(self):
         name_str = ', name={}'.format(self._name) if self._name else ''
-        return 'num_parameters={}, init={}, dtype={}{}'.format(
-            self._num_parameters, self._init, self._dtype, name_str)
+        return 'num_parameters={}, data_format={}, init={}, dtype={}{}'.format(
+            self._num_parameters, self._data_format, self._init, self._dtype,
+            name_str)


 class ReLU(Layer):
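Finally, a dygraph sketch of the layer API with the new argument (illustrative usage, not part of the diff):

    # nn.PReLU with channel-last data (illustrative).
    import paddle
    import paddle.nn as nn

    m = nn.PReLU(num_parameters=6, init=0.25, data_format="NHWC")
    x = paddle.uniform([4, 10, 10, 6], min=-1., max=1.)   # NHWC, C = 6
    y = m(x)
    print(m)  # PReLU(num_parameters=6, data_format=NHWC, init=0.25, dtype=float32)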