Unverified commit d49c3061, authored by liu zhengxi, committed by GitHub

API(dynamic_lstm, dynamic_lstmp) error message enhancement (#24450) (#24512)

* update err msg for dynamic_lstm and dynamic_lstmp, test=develop
Parent commit: 008857be
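For context, a minimal static-graph usage sketch (not part of this commit; fluid 1.x API as exercised by the tests below, sizes illustrative) of the call whose argument checking this change improves:

```python
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

with program_guard(Program(), Program()):
    # dynamic_lstm expects a float32/float64 LoDTensor Variable whose last
    # dimension equals size, i.e. 4 * hidden_size (here 4 * 512 = 2048).
    emb = fluid.data(name="emb", shape=[None, 2048], dtype="float32", lod_level=1)
    hidden, cell = fluid.layers.dynamic_lstm(
        input=emb, size=2048, use_peepholes=False)
```

Passing anything that violates these expectations previously failed with terse or misleading messages; the diff below rewrites the C++ `InferShape` checks to `OP_INOUT_CHECK` / `PADDLE_ENFORCE_EQ` with `platform::errors` (so the received value is reported), and adds type/dtype validation to the Python wrappers.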
@@ -24,64 +24,80 @@ class LSTMOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchGate) of LSTM should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM");
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
"BatchCellPreAct", "LSTM");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(
in_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X)'s rank must be 2, but received %d.", in_dims.size()));
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("C0"), true,
platform::errors::NotFound("Input(Cell) and Input(Hidden) of LSTM "
"should not be null at the same time."));
auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
PADDLE_ENFORCE_EQ(h_dims, c_dims,
platform::errors::InvalidArgument(
"The dimension of Input(H0) and Input(C0) should "
"be the same, but received [%s] (H0) vs [%s] (C0).",
h_dims, c_dims));
}
int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"The rank of Input(Weight) should be 2.");
PADDLE_ENFORCE_EQ(
w_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(Weight) should be 2, but received %d.",
w_dims.size()));
PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
"The first dimension of Input(Weight) "
"should be %d.",
frame_size);
platform::errors::InvalidArgument(
"The first dimension of Input(Weight) should be %d, "
"but received %d.",
frame_size, w_dims[0]));
PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
"The second dimension of Input(Weight) "
"should be 4 * %d.",
frame_size);
platform::errors::InvalidArgument(
"The second dimension of Input(Weight) should be 4 * "
"%d, but received %d.",
frame_size, w_dims[1]));
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE_EQ(
b_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(Bias) should be 2, but received %d.",
b_dims.size()));
PADDLE_ENFORCE_EQ(
b_dims[0], 1,
platform::errors::InvalidArgument(
"The first dimension of Input(Bias) should be 1, but received %d.",
b_dims[0]));
if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection",
frame_size);
PADDLE_ENFORCE_EQ(
b_dims[1], 7 * frame_size,
platform::errors::InvalidArgument(
"The second dimension of Input(Bias) should be 7 * %d if enable "
"peepholes connection, but received %d.",
frame_size, b_dims[1]));
} else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection",
frame_size);
PADDLE_ENFORCE_EQ(
b_dims[1], 4 * frame_size,
platform::errors::InvalidArgument(
"The second dimension of Input(Bias) should be 4 * %d if disable "
"peepholes connection, but received %d.",
frame_size, b_dims[1]));
}
framework::DDim out_dims({in_dims[0], frame_size});
@@ -229,21 +245,16 @@ class LSTMGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cell"),
"Input(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTM should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM@Grad");
OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "LSTM@Grad");
OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTM@Grad");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM@Grad");
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM@Grad");
OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
"LSTM@Grad");
OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
"LSTM@Grad");
auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name);
......
@@ -24,74 +24,92 @@ class LSTMPOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
"Input(ProjWeight) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Projection"),
"Output(Projection) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(BatchGate) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchCellPreAct) of LSTMP operator should not be "
"null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
"Output(BatchHidden) of LSTMP operator should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTMP");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP");
OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight", "LSTMP");
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP");
OP_INOUT_CHECK(ctx->HasOutput("Projection"), "Output", "Projection",
"LSTMP");
OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTMP");
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTMP");
OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
"BatchCellPreAct", "LSTMP");
OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
"LSTMP");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 2,
"Input(X)'s rank of LSTMP operator must be 2.");
PADDLE_ENFORCE_EQ(
in_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X)'s rank of LSTMP operator must be 2, but received %d.",
in_dims.size()));
int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight");
auto proj_dims = ctx->GetInputDim("ProjWeight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"The rank of Input(Weight) should be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], proj_dims[1],
"The first dimension of Input(Weight) "
"should be %d.",
proj_dims[1]);
PADDLE_ENFORCE_EQ(
w_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(Weight) should be 2, but received %d.",
w_dims.size()));
PADDLE_ENFORCE_EQ(
w_dims[0], proj_dims[1],
platform::errors::InvalidArgument(
"The first dimension of Input(Weight) and the second dimension of "
"Input(ProjWeight) should be the same, but received %d vs %d.",
w_dims[0], proj_dims[1]));
PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
"The second dimension of Input(Weight) "
"should be 4 * %d.",
frame_size);
PADDLE_ENFORCE_EQ(proj_dims.size(), 2,
"The rank of Input(ProjWeight) should be 2.");
platform::errors::InvalidArgument(
"The second dimension of Input(Weight) should be 4 * "
"%d, but received %d.",
frame_size, w_dims[1]));
PADDLE_ENFORCE_EQ(
proj_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(ProjWeight) should be 2, but received %d.",
proj_dims.size()));
PADDLE_ENFORCE_EQ(proj_dims[0], frame_size,
"The first dimension of Input(ProjWeight) "
"should be %d.",
frame_size);
platform::errors::InvalidArgument(
"The first dimension of Input(ProjWeight) should be "
"%d, but received %d.",
frame_size, proj_dims[0]));
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) of LSTMP operator should not be null after "
"Input(H0) provided.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("C0"), true,
platform::errors::NotFound("Input(C0) of LSTMP operator should not "
"be null after Input(H0) provided."));
}
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE_EQ(
b_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(Bias) should be 2, but received %d.",
b_dims.size()));
PADDLE_ENFORCE_EQ(
b_dims[0], 1,
platform::errors::InvalidArgument(
"The first dimension of Input(Bias) should be 1, but received %d.",
b_dims[0]));
if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection",
frame_size);
PADDLE_ENFORCE_EQ(
b_dims[1], 7 * frame_size,
platform::errors::InvalidArgument(
"The second dimension of Input(Bias) should be 7 * %d if enable "
"peepholes connection, but received %d.",
frame_size, b_dims[1]));
} else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection",
frame_size);
PADDLE_ENFORCE_EQ(
b_dims[1], 4 * frame_size,
platform::errors::InvalidArgument(
"The second dimension of Input(Bias) should be 4 * %d if disable "
"peepholes connection, but received %d.",
frame_size, b_dims[1]));
}
framework::DDim out_dims({in_dims[0], frame_size});
@@ -314,21 +332,18 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Projection"),
"Input(Projection) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cell"),
"Input(Cell) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
"Input(ProjWeight) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(BatchGate) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTMP operator should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Projection"), "Input", "Projection",
"LSTMP@Grad");
OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTMP@Grad");
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP@Grad");
OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight",
"LSTMP@Grad");
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP@Grad");
OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
"LSTMP@Grad");
OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
"LSTMP@Grad");
auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name);
......
@@ -2030,7 +2030,21 @@ def dynamic_lstm(input,
"""
assert in_dygraph_mode(
) is not True, "please use lstm instead of dynamic_lstm in dygraph mode!"
assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp."
assert bias_attr is not False, "bias_attr should not be False in dynamic_lstm."
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'dynamic_lstm')
check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_lstm')
if isinstance(h_0, Variable):
check_variable_and_dtype(h_0, 'h_0', ['float32', 'float64'],
'dynamic_lstm')
check_type(c_0, 'c_0', (Variable, type(None)), 'dynamic_lstm')
if isinstance(c_0, Variable):
check_variable_and_dtype(c_0, 'c_0', ['float32', 'float64'],
'dynamic_lstm')
helper = LayerHelper('lstm', **locals())
size = size // 4
weight = helper.create_parameter(
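With the checks added above, dynamic_lstm now fails fast with a TypeError when `input` is not a Variable or when `h_0`/`c_0` have a non-float dtype, instead of surfacing an opaque error later. A hedged sketch of that behavior (it mirrors the unit tests added further down; fluid 1.x API assumed, sizes illustrative):

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

with program_guard(Program(), Program()):
    # A raw numpy array is rejected: check_variable_and_dtype only accepts
    # a float32/float64 Variable for `input`.
    try:
        fluid.layers.dynamic_lstm(
            input=np.random.random((1, 2048)).astype("float32"),
            size=2048, use_peepholes=False)
    except TypeError as e:
        print("input check:", e)

    # An int32 h_0 trips the new check_variable_and_dtype(h_0, ...) guard.
    x = fluid.data(name="x", shape=[None, 2048], dtype="float32", lod_level=1)
    h = fluid.data(name="h", shape=[None, 512], dtype="int32")
    c = fluid.data(name="c", shape=[None, 512], dtype="float32")
    try:
        fluid.layers.dynamic_lstm(
            input=x, size=2048, use_peepholes=False, h_0=h, c_0=c)
    except TypeError as e:
        print("h_0 check:", e)
```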
@@ -2396,6 +2410,20 @@ def dynamic_lstmp(input,
) is not True, "please use lstm instead of dynamic_lstmp in dygraph mode!"
assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp."
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'dynamic_lstmp')
check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_lstmp')
if isinstance(h_0, Variable):
check_variable_and_dtype(h_0, 'h_0', ['float32', 'float64'],
'dynamic_lstmp')
check_type(c_0, 'c_0', (Variable, type(None)), 'dynamic_lstmp')
if isinstance(c_0, Variable):
check_variable_and_dtype(c_0, 'c_0', ['float32', 'float64'],
'dynamic_lstmp')
helper = LayerHelper('lstmp', **locals())
size = size // 4
weight = helper.create_parameter(
......
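dynamic_lstmp gets the identical input/h_0/c_0 guards, on top of its projection-specific arguments. A sketch of a valid call (fluid 1.x API assumed; sizes and activations mirror the tests below):

```python
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

with program_guard(Program(), Program()):
    x = fluid.data(name="x", shape=[None, 2048], dtype="float32", lod_level=1)
    proj, cell = fluid.layers.dynamic_lstmp(
        input=x,
        size=2048,           # 4 * hidden_size
        proj_size=256,       # hidden state is projected to this width
        use_peepholes=False,
        is_reverse=True,
        cell_activation="tanh",
        proj_activation="tanh")
```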
@@ -301,6 +301,42 @@ class TestLstmOpCase3(TestLstmOp):
self.lod = [[2, 0, 4]]
class TestLstmOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
input_data = np.random.random((1, 2048)).astype("float32")
fluid.layers.dynamic_lstm(
input=input_data, size=2048, use_peepholes=False)
self.assertRaises(TypeError, test_Variable)
def test_h_0():
in_data = fluid.data(
name="input", shape=[None, 2048], dtype="float32")
h = fluid.data(name="h", shape=[None, 512], dtype="int32")
c = fluid.data(name="c", shape=[None, 512], dtype="float32")
fluid.layers.dynamic_lstm(
input=in_data, size=2048, use_peepholes=False, h_0=h, c_0=c)
self.assertRaises(TypeError, test_h_0)
def test_c_0():
in_data_ = fluid.data(
name="input_", shape=[None, 2048], dtype="float32")
h_ = fluid.data(name="h_", shape=[None, 512], dtype="float32")
c_ = fluid.data(name="c_", shape=[None, 512], dtype="int32")
fluid.layers.dynamic_lstm(
input=in_data_,
size=2048,
use_peepholes=False,
h_0=h_,
c_0=c_)
self.assertRaises(TypeError, test_c_0)
# class TestLstmOpHasInitial(TestLstmOp):
# def set_argument(self):
# self.lod = [[2, 3, 2]]
......
@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
import test_lstm_op as LstmTest
from paddle import fluid
from paddle.fluid import Program, program_guard
ACTIVATION = {
'identity': LstmTest.identity,
@@ -315,5 +317,59 @@ class TestLstmpOpLen0Case2(TestLstmpOp):
self.lod = [[2, 0, 3]]
class TestLstmpOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
input_data = np.random.random((1, 2048)).astype("float32")
fluid.layers.dynamic_lstmp(
input=input_data,
size=2048,
proj_size=256,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh")
self.assertRaises(TypeError, test_Variable)
def test_h_0():
in_data = fluid.data(
name="input", shape=[None, 2048], dtype="float32")
h = fluid.data(name="h", shape=[None, 512], dtype="int32")
c = fluid.data(name="c", shape=[None, 512], dtype="float32")
fluid.layers.dynamic_lstmp(
input=in_data,
size=2048,
proj_size=256,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh",
h_0=h,
c_0=c)
self.assertRaises(TypeError, test_h_0)
def test_c_0():
in_data_ = fluid.data(
name="input_", shape=[None, 2048], dtype="float32")
h_ = fluid.data(name="h_", shape=[None, 512], dtype="float32")
c_ = fluid.data(name="c_", shape=[None, 512], dtype="int32")
fluid.layers.dynamic_lstmp(
input=in_data_,
size=2048,
proj_size=256,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh",
h_0=h_,
c_0=c_)
self.assertRaises(TypeError, test_c_0)
if __name__ == '__main__':
unittest.main()