Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5b3dd806
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5b3dd806
编写于
4月 10, 2020
作者:
Z
zhupengyang
提交者:
GitHub
4月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Op(prelu) error message enhancement (#23616)
上级
0581d74d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
50 addition
and
23 deletion
+50
-23
paddle/fluid/operators/prelu_op.cc
paddle/fluid/operators/prelu_op.cc
+32
-23
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+2
-0
python/paddle/fluid/tests/unittests/test_prelu_op.py
python/paddle/fluid/tests/unittests/test_prelu_op.py
+16
-0
未找到文件。
paddle/fluid/operators/prelu_op.cc
浏览文件 @
5b3dd806
...
...
@@ -24,40 +24,48 @@ class PReluOp : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
std
::
string
mode
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"mode"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"prelu"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Alpha"
),
"Input"
,
"Alpha"
,
"prelu"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"prelu"
);
auto
x_dim
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of PreluOp should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Alpha"
),
"Input(Alpha) of PreluOp should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of PreluOp should not be null"
);
std
::
string
mode
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"mode"
);
if
(
mode
==
"all"
)
{
PADDLE_ENFORCE
(
product
(
ctx
->
GetInputDim
(
"Alpha"
))
==
1
,
"For mode 'all', size of weight Alpha must be one."
);
PADDLE_ENFORCE_EQ
(
product
(
ctx
->
GetInputDim
(
"Alpha"
)),
1
,
platform
::
errors
::
InvalidArgument
(
"For mode 'all', size of weight Alpha must be one."
));
}
else
if
(
mode
==
"channel"
)
{
PADDLE_ENFORCE
(
product
(
ctx
->
GetInputDim
(
"Alpha"
))
==
x_dim
[
1
],
"For channel-wise mode, size of weight Alpha must be "
"equal to the number of channels, should be %d"
,
x_dim
[
1
]);
PADDLE_ENFORCE_EQ
(
product
(
ctx
->
GetInputDim
(
"Alpha"
)),
x_dim
[
1
],
platform
::
errors
::
InvalidArgument
(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d"
,
product
(
ctx
->
GetInputDim
(
"Alpha"
)),
x_dim
[
1
]));
}
else
if
(
mode
==
"element"
)
{
auto
alpha_dim
=
ctx
->
GetInputDim
(
"Alpha"
);
auto
alpha_rank
=
alpha_dim
.
size
();
auto
x_rank
=
x_dim
.
size
();
PADDLE_ENFORCE_EQ
(
alpha_rank
,
x_rank
,
platform
::
errors
::
InvalidArgument
(
"For mode 'element', rank of weight Alpha must be "
,
"equal to the rank of input(x). But recevied alpha's rank: %d, "
"x's rank: %d."
,
alpha_rank
,
x_rank
));
size_t
x_product
=
1
;
size_t
alpha_product
=
1
;
PADDLE_ENFORCE_EQ
(
alpha_rank
,
x_rank
,
"For element-wise mode, rank of weight Alpha must be "
,
"equal to the rank of input."
);
for
(
int64_t
i
=
x_rank
-
1
;
i
>
0
;
i
--
)
{
x_product
*=
x_dim
[
i
];
alpha_product
*=
alpha_dim
[
i
];
}
PADDLE_ENFORCE_EQ
(
x_product
,
alpha_product
,
"For element-wise mode, size of weight Alpha must be "
"equal to the number of input."
);
PADDLE_ENFORCE_EQ
(
alpha_product
,
x_product
,
platform
::
errors
::
InvalidArgument
(
"For mode 'element', the size of weight Alpha must be "
"equal to the size of input(x). But recevied alpha's size: %d, "
"x's size: %d."
,
alpha_product
,
x_product
));
}
else
{
PADDLE_THROW
(
"Unkown mode %s"
,
mode
);
}
...
...
@@ -108,9 +116,10 @@ class PReluGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"prelu"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
"Out@GRAD"
,
"prelu"
);
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
auto
alpha_grad_name
=
framework
::
GradVarName
(
"Alpha"
);
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
5b3dd806
...
...
@@ -9132,6 +9132,8 @@ def prelu(x, mode, param_attr=None, name=None):
x,mode,param_attr=ParamAttr(name='alpha'))
"""
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'prelu')
helper = LayerHelper('prelu', **locals())
if mode not in ['all', 'channel', 'element']:
raise ValueError('mode should be one of all, channel, element.')
...
...
python/paddle/fluid/tests/unittests/test_prelu_op.py
浏览文件 @
5b3dd806
...
...
@@ -17,6 +17,8 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
import
six
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
from
op_test
import
OpTest
,
skip_check_grad_ci
...
...
@@ -80,5 +82,19 @@ if six.PY2:
self
.
attrs
=
{
'mode'
:
"element"
}
class
TestPReluOpError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
with
program_guard
(
Program
()):
# The input type must be Variable.
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
prelu
,
1
,
'all'
)
# The input dtype must be float16, float32, float64.
x_int32
=
fluid
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
prelu
,
x_int32
,
'all'
)
# support the input dtype is float32
x_fp16
=
fluid
.
layers
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float32'
)
fluid
.
layers
.
prelu
(
x_fp16
,
'all'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录