Unverified · Commit 5b3dd806 authored by zhupengyang, committed by GitHub

Op(prelu) error message enhancement (#23616)
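
This commit replaces the prelu op's bare `PADDLE_ENFORCE(cond, msg)` checks with `OP_INOUT_CHECK` for input/output presence and `PADDLE_ENFORCE_EQ(..., platform::errors::InvalidArgument(...))` for shape checks, so a failed check reports the values it actually received. A minimal standalone sketch of the idea, assuming nothing beyond the C++ standard library; the `enforce_eq` helper below is a hypothetical stand-in for Paddle's macros, not the real implementation:

```cpp
#include <cstdio>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for PADDLE_ENFORCE_EQ with
// platform::errors::InvalidArgument: on mismatch, throw an error whose
// message carries both the expectation and the values actually received.
template <typename T>
void enforce_eq(const T &actual, const T &expected, const std::string &what) {
  if (actual == expected) return;
  std::ostringstream msg;
  msg << "InvalidArgument: " << what << " But received " << actual
      << ", expected " << expected << ".";
  throw std::invalid_argument(msg.str());
}

int main() {
  // Mimic the 'channel' mode check: Alpha's size must equal the channel count.
  int64_t alpha_size = 3, channels = 4;
  try {
    enforce_eq(alpha_size, channels,
               "For mode 'channel', size of weight Alpha must be equal to "
               "the number of channels of input(x).");
  } catch (const std::invalid_argument &e) {
    std::printf("%s\n", e.what());  // enriched message, not a bare assert
  }
  return 0;
}
```

Reporting the received values in the exception text is what distinguishes the new `platform::errors::InvalidArgument` messages from the old one-line strings.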

Parent 0581d74d
@@ -24,40 +24,48 @@ class PReluOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
+    std::string mode = ctx->Attrs().Get<std::string>("mode");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "prelu");
+    OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "prelu");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "prelu");
     auto x_dim = ctx->GetInputDim("X");
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of PreluOp should not be null");
-    PADDLE_ENFORCE(ctx->HasInput("Alpha"),
-                   "Input(Alpha) of PreluOp should not be null");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of PreluOp should not be null");
-    std::string mode = ctx->Attrs().Get<std::string>("mode");
     if (mode == "all") {
-      PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
-                     "For mode 'all', size of weight Alpha must be one.");
+      PADDLE_ENFORCE_EQ(
+          product(ctx->GetInputDim("Alpha")), 1,
+          platform::errors::InvalidArgument(
+              "For mode 'all', size of weight Alpha must be one."));
     } else if (mode == "channel") {
-      PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == x_dim[1],
-                     "For channel-wise mode, size of weight Alpha must be "
-                     "equal to the number of channels, should be %d",
-                     x_dim[1]);
+      PADDLE_ENFORCE_EQ(product(ctx->GetInputDim("Alpha")), x_dim[1],
+                        platform::errors::InvalidArgument(
+                            "For mode 'channel', size of weight Alpha must be "
+                            "equal to the number of channels of input(x). But "
+                            "received alpha's size: %d, x_dim[1]: %d",
+                            product(ctx->GetInputDim("Alpha")), x_dim[1]));
     } else if (mode == "element") {
       auto alpha_dim = ctx->GetInputDim("Alpha");
       auto alpha_rank = alpha_dim.size();
       auto x_rank = x_dim.size();
+      PADDLE_ENFORCE_EQ(
+          alpha_rank, x_rank,
+          platform::errors::InvalidArgument(
+              "For mode 'element', rank of weight Alpha must be "
+              "equal to the rank of input(x). But received alpha's rank: %d, "
+              "x's rank: %d.",
+              alpha_rank, x_rank));
       size_t x_product = 1;
       size_t alpha_product = 1;
-      PADDLE_ENFORCE_EQ(alpha_rank, x_rank,
-                        "For element-wise mode, rank of weight Alpha must be ",
-                        "equal to the rank of input.");
       for (int64_t i = x_rank - 1; i > 0; i--) {
         x_product *= x_dim[i];
         alpha_product *= alpha_dim[i];
       }
-      PADDLE_ENFORCE_EQ(x_product, alpha_product,
-                        "For element-wise mode, size of weight Alpha must be "
-                        "equal to the number of input.");
+      PADDLE_ENFORCE_EQ(
+          alpha_product, x_product,
+          platform::errors::InvalidArgument(
+              "For mode 'element', the size of weight Alpha must be "
+              "equal to the size of input(x). But received alpha's size: %d, "
+              "x's size: %d.",
+              alpha_product, x_product));
     } else {
       PADDLE_THROW("Unknown mode %s", mode);
     }
@@ -108,9 +116,10 @@ class PReluGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) should not be null");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "prelu");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@GRAD", "prelu");
     auto x_grad_name = framework::GradVarName("X");
     auto alpha_grad_name = framework::GradVarName("Alpha");
@@ -9132,6 +9132,8 @@ def prelu(x, mode, param_attr=None, name=None):
                 x,mode,param_attr=ParamAttr(name='alpha'))
     """
+    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'prelu')
     helper = LayerHelper('prelu', **locals())
     if mode not in ['all', 'channel', 'element']:
         raise ValueError('mode should be one of all, channel, element.')
@@ -17,6 +17,8 @@ from __future__ import print_function
 import unittest
 import numpy as np
 import six
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
 from op_test import OpTest, skip_check_grad_ci
@@ -80,5 +82,19 @@ if six.PY2:
         self.attrs = {'mode': "element"}
+
+
+class TestPReluOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program()):
+            # The input type must be Variable.
+            self.assertRaises(TypeError, fluid.layers.prelu, 1, 'all')
+            # The input dtype must be float32 or float64.
+            x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
+            self.assertRaises(TypeError, fluid.layers.prelu, x_int32, 'all')
+            # float32 is a supported dtype, so this should not raise.
+            x_fp32 = fluid.layers.data(
+                name='x_fp32', shape=[12, 10], dtype='float32')
+            fluid.layers.prelu(x_fp32, 'all')
 
 
 if __name__ == "__main__":
     unittest.main()