未验证 提交 5b3dd806 编写于 作者: Z zhupengyang 提交者: GitHub

Op(prelu) error message enhancement (#23616)

上级 0581d74d
...@@ -24,40 +24,48 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -24,40 +24,48 @@ class PReluOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
std::string mode = ctx->Attrs().Get<std::string>("mode"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "prelu");
OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "prelu");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "prelu");
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE(ctx->HasInput("X"), std::string mode = ctx->Attrs().Get<std::string>("mode");
"Input(X) of PreluOp should not be null");
PADDLE_ENFORCE(ctx->HasInput("Alpha"),
"Input(Alpha) of PreluOp should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PreluOp should not be null");
if (mode == "all") { if (mode == "all") {
PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, PADDLE_ENFORCE_EQ(
"For mode 'all', size of weight Alpha must be one."); product(ctx->GetInputDim("Alpha")), 1,
platform::errors::InvalidArgument(
"For mode 'all', size of weight Alpha must be one."));
} else if (mode == "channel") { } else if (mode == "channel") {
PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == x_dim[1], PADDLE_ENFORCE_EQ(product(ctx->GetInputDim("Alpha")), x_dim[1],
"For channel-wise mode, size of weight Alpha must be " platform::errors::InvalidArgument(
"equal to the number of channels, should be %d", "For mode 'channel', size of weight Alpha must be "
x_dim[1]); "equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1]));
} else if (mode == "element") { } else if (mode == "element") {
auto alpha_dim = ctx->GetInputDim("Alpha"); auto alpha_dim = ctx->GetInputDim("Alpha");
auto alpha_rank = alpha_dim.size(); auto alpha_rank = alpha_dim.size();
auto x_rank = x_dim.size(); auto x_rank = x_dim.size();
PADDLE_ENFORCE_EQ(
alpha_rank, x_rank,
platform::errors::InvalidArgument(
"For mode 'element', rank of weight Alpha must be ",
"equal to the rank of input(x). But recevied alpha's rank: %d, "
"x's rank: %d.",
alpha_rank, x_rank));
size_t x_product = 1; size_t x_product = 1;
size_t alpha_product = 1; size_t alpha_product = 1;
PADDLE_ENFORCE_EQ(alpha_rank, x_rank,
"For element-wise mode, rank of weight Alpha must be ",
"equal to the rank of input.");
for (int64_t i = x_rank - 1; i > 0; i--) { for (int64_t i = x_rank - 1; i > 0; i--) {
x_product *= x_dim[i]; x_product *= x_dim[i];
alpha_product *= alpha_dim[i]; alpha_product *= alpha_dim[i];
} }
PADDLE_ENFORCE_EQ(x_product, alpha_product, PADDLE_ENFORCE_EQ(
"For element-wise mode, size of weight Alpha must be " alpha_product, x_product,
"equal to the number of input."); platform::errors::InvalidArgument(
"For mode 'element', the size of weight Alpha must be "
"equal to the size of input(x). But recevied alpha's size: %d, "
"x's size: %d.",
alpha_product, x_product));
} else { } else {
PADDLE_THROW("Unkown mode %s", mode); PADDLE_THROW("Unkown mode %s", mode);
} }
...@@ -108,9 +116,10 @@ class PReluGradOp : public framework::OperatorWithKernel { ...@@ -108,9 +116,10 @@ class PReluGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "prelu");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Input(Out@GRAD) should not be null"); "Out@GRAD", "prelu");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto alpha_grad_name = framework::GradVarName("Alpha"); auto alpha_grad_name = framework::GradVarName("Alpha");
......
...@@ -9132,6 +9132,8 @@ def prelu(x, mode, param_attr=None, name=None): ...@@ -9132,6 +9132,8 @@ def prelu(x, mode, param_attr=None, name=None):
x,mode,param_attr=ParamAttr(name='alpha')) x,mode,param_attr=ParamAttr(name='alpha'))
""" """
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'prelu')
helper = LayerHelper('prelu', **locals()) helper = LayerHelper('prelu', **locals())
if mode not in ['all', 'channel', 'element']: if mode not in ['all', 'channel', 'element']:
raise ValueError('mode should be one of all, channel, element.') raise ValueError('mode should be one of all, channel, element.')
......
...@@ -17,6 +17,8 @@ from __future__ import print_function ...@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import six import six
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from op_test import OpTest, skip_check_grad_ci from op_test import OpTest, skip_check_grad_ci
...@@ -80,5 +82,19 @@ if six.PY2: ...@@ -80,5 +82,19 @@ if six.PY2:
self.attrs = {'mode': "element"} self.attrs = {'mode': "element"}
class TestPReluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.prelu, 1, 'all')
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.prelu, x_int32, 'all')
# support the input dtype is float32
x_fp16 = fluid.layers.data(
name='x_fp16', shape=[12, 10], dtype='float32')
fluid.layers.prelu(x_fp16, 'all')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册