未验证 提交 417b576c 编写于 作者: L liu zhengxi 提交者: GitHub

API(dynamic_lstm, dynamic_lstmp) error message enhancement (#24450)

* update err msg for dynamic_lstm and dynamic_lstmp, test=develop
上级 53bdee64
...@@ -24,64 +24,80 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -24,64 +24,80 @@ class LSTMOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM");
"Input(Input) of LSTM should not be null."); OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM");
PADDLE_ENFORCE(ctx->HasInput("Weight"), OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM");
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"), OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM");
"Input(Bias) of LSTM should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM");
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
"Output(Hidden) of LSTM should not be null."); "BatchCellPreAct", "LSTM");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchGate) of LSTM should not be null.");
auto in_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(
in_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X)'s rank must be 2, but received %d.", in_dims.size()));
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"), PADDLE_ENFORCE_EQ(
"Input(Cell) and Input(Hidden) of LSTM should not " ctx->HasInput("C0"), true,
"be null at the same time."); platform::errors::NotFound("Input(Cell) and Input(Hidden) of LSTM "
"should not be null at the same time."));
auto h_dims = ctx->GetInputDim("H0"); auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0"); auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims, PADDLE_ENFORCE_EQ(h_dims, c_dims,
"The dimension of Input(H0) and Input(C0) " platform::errors::InvalidArgument(
"should be the same."); "The dimension of Input(H0) and Input(C0) should "
"be the same, but received [%s] (H0) vs [%s] (C0).",
h_dims, c_dims));
} }
int frame_size = in_dims[1] / 4; int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight"); auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2, PADDLE_ENFORCE_EQ(
"The rank of Input(Weight) should be 2."); w_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(Weight) should be 2, but received %d.",
w_dims.size()));
PADDLE_ENFORCE_EQ(w_dims[0], frame_size, PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
"The first dimension of Input(Weight) " platform::errors::InvalidArgument(
"should be %d.", "The first dimension of Input(Weight) should be %d, "
frame_size); "but received %d.",
frame_size, w_dims[0]));
PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size, PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
"The second dimension of Input(Weight) " platform::errors::InvalidArgument(
"should be 4 * %d.", "The second dimension of Input(Weight) should be 4 * "
frame_size); "%d, but received %d.",
frame_size, w_dims[1]));
auto b_dims = ctx->GetInputDim("Bias"); auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(b_dims[0], 1, b_dims.size(), 2,
"The first dimension of Input(Bias) should be 1."); platform::errors::InvalidArgument(
"The rank of Input(Bias) should be 2, but received %d.",
b_dims.size()));
PADDLE_ENFORCE_EQ(
b_dims[0], 1,
platform::errors::InvalidArgument(
"The first dimension of Input(Bias) should be 1, but received %d.",
b_dims[0]));
if (ctx->Attrs().Get<bool>("use_peepholes")) { if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, PADDLE_ENFORCE_EQ(
"The second dimension of Input(Bias) should be " b_dims[1], 7 * frame_size,
"7 * %d if enable peepholes connection", platform::errors::InvalidArgument(
frame_size); "The second dimension of Input(Bias) should be 7 * %d if enable "
"peepholes connection, but received %d.",
frame_size, b_dims[1]));
} else { } else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, PADDLE_ENFORCE_EQ(
"The second dimension of Input(Bias) should be " b_dims[1], 4 * frame_size,
"4 * %d if disable peepholes connection", platform::errors::InvalidArgument(
frame_size); "The second dimension of Input(Bias) should be 4 * %d if disable "
"peepholes connection, but received %d.",
frame_size, b_dims[1]));
} }
framework::DDim out_dims({in_dims[0], frame_size}); framework::DDim out_dims({in_dims[0], frame_size});
...@@ -229,21 +245,16 @@ class LSTMGradOp : public framework::OperatorWithKernel { ...@@ -229,21 +245,16 @@ class LSTMGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM@Grad");
"Input(Input) of LSTM should not be null."); OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "LSTM@Grad");
PADDLE_ENFORCE(ctx->HasInput("Hidden"), OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTM@Grad");
"Input(Hidden) of LSTM should not be null."); OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM@Grad");
PADDLE_ENFORCE(ctx->HasInput("Cell"), OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM@Grad");
"Input(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"), OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
"Input(Weight) of LSTM should not be null."); "LSTM@Grad");
PADDLE_ENFORCE(ctx->HasInput("Bias"), OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
"Input(Bias) of LSTM should not be null."); "LSTM@Grad");
PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTM should not be null.");
auto SetOutGradDim = [&ctx](const std::string& name) { auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name); auto g_name = framework::GradVarName(name);
......
...@@ -24,74 +24,92 @@ class LSTMPOp : public framework::OperatorWithKernel { ...@@ -24,74 +24,92 @@ class LSTMPOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTMP");
"Input(Input) of LSTMP operator should not be null."); OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP");
PADDLE_ENFORCE(ctx->HasInput("Weight"), OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight", "LSTMP");
"Input(Weight) of LSTMP operator should not be null."); OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP");
PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
"Input(ProjWeight) of LSTMP operator should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Projection"), "Output", "Projection",
PADDLE_ENFORCE(ctx->HasInput("Bias"), "LSTMP");
"Input(Bias) of LSTMP operator should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTMP");
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTMP");
PADDLE_ENFORCE(ctx->HasOutput("Projection"), OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
"Output(Projection) of LSTMP operator should not be null."); "BatchCellPreAct", "LSTMP");
PADDLE_ENFORCE(ctx->HasOutput("Cell"), OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
"Output(Cell) of LSTMP operator should not be null."); "LSTMP");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(BatchGate) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchCellPreAct) of LSTMP operator should not be "
"null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
"Output(BatchHidden) of LSTMP operator should not be null.");
auto in_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 2, PADDLE_ENFORCE_EQ(
"Input(X)'s rank of LSTMP operator must be 2."); in_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X)'s rank of LSTMP operator must be 2, but received %d.",
in_dims.size()));
int frame_size = in_dims[1] / 4; int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight"); auto w_dims = ctx->GetInputDim("Weight");
auto proj_dims = ctx->GetInputDim("ProjWeight"); auto proj_dims = ctx->GetInputDim("ProjWeight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2, PADDLE_ENFORCE_EQ(
"The rank of Input(Weight) should be 2."); w_dims.size(), 2,
PADDLE_ENFORCE_EQ(w_dims[0], proj_dims[1], platform::errors::InvalidArgument(
"The first dimension of Input(Weight) " "The rank of Input(Weight) should be 2, but received %d.",
"should be %d.", w_dims.size()));
proj_dims[1]); PADDLE_ENFORCE_EQ(
w_dims[0], proj_dims[1],
platform::errors::InvalidArgument(
"The first dimension of Input(Weight) and the second dimension of "
"Input(ProjWeight) should be the same, but received %d vs %d.",
w_dims[0], proj_dims[1]));
PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size, PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
"The second dimension of Input(Weight) " platform::errors::InvalidArgument(
"should be 4 * %d.", "The second dimension of Input(Weight) should be 4 * "
frame_size); "%d, but received %d.",
frame_size, w_dims[1]));
PADDLE_ENFORCE_EQ(proj_dims.size(), 2,
"The rank of Input(ProjWeight) should be 2."); PADDLE_ENFORCE_EQ(
proj_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(ProjWeight) should be 2, but received %d.",
proj_dims.size()));
PADDLE_ENFORCE_EQ(proj_dims[0], frame_size, PADDLE_ENFORCE_EQ(proj_dims[0], frame_size,
"The first dimension of Input(ProjWeight) " platform::errors::InvalidArgument(
"should be %d.", "The first dimension of Input(ProjWeight) should be "
frame_size); "%d, but received %d.",
frame_size, proj_dims[0]));
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"), PADDLE_ENFORCE_EQ(
"Input(C0) of LSTMP operator should not be null after " ctx->HasInput("C0"), true,
"Input(H0) provided."); platform::errors::NotFound("Input(C0) of LSTMP operator should not "
"be null after Input(H0) provided."));
} }
auto b_dims = ctx->GetInputDim("Bias"); auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(b_dims[0], 1, b_dims.size(), 2,
"The first dimension of Input(Bias) should be 1."); platform::errors::InvalidArgument(
"The rank of Input(Bias) should be 2, but received %d.",
b_dims.size()));
PADDLE_ENFORCE_EQ(
b_dims[0], 1,
platform::errors::InvalidArgument(
"The first dimension of Input(Bias) should be 1, but received %d.",
b_dims[0]));
if (ctx->Attrs().Get<bool>("use_peepholes")) { if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, PADDLE_ENFORCE_EQ(
"The second dimension of Input(Bias) should be " b_dims[1], 7 * frame_size,
"7 * %d if enable peepholes connection", platform::errors::InvalidArgument(
frame_size); "The second dimension of Input(Bias) should be 7 * %d if enable "
"peepholes connection, but received %d.",
frame_size, b_dims[1]));
} else { } else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, PADDLE_ENFORCE_EQ(
"The second dimension of Input(Bias) should be " b_dims[1], 4 * frame_size,
"4 * %d if disable peepholes connection", platform::errors::InvalidArgument(
frame_size); "The second dimension of Input(Bias) should be 4 * %d if disable "
"peepholes connection, but received %d.",
frame_size, b_dims[1]));
} }
framework::DDim out_dims({in_dims[0], frame_size}); framework::DDim out_dims({in_dims[0], frame_size});
...@@ -314,21 +332,18 @@ class LSTMPGradOp : public framework::OperatorWithKernel { ...@@ -314,21 +332,18 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Projection"), OP_INOUT_CHECK(ctx->HasInput("Projection"), "Input", "Projection",
"Input(Projection) of LSTMP operator should not be null."); "LSTMP@Grad");
PADDLE_ENFORCE(ctx->HasInput("Cell"), OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTMP@Grad");
"Input(Cell) of LSTMP operator should not be null."); OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP@Grad");
PADDLE_ENFORCE(ctx->HasInput("Weight"), OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight",
"Input(Weight) of LSTMP operator should not be null."); "LSTMP@Grad");
PADDLE_ENFORCE(ctx->HasInput("ProjWeight"), OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP@Grad");
"Input(ProjWeight) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"), OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
"Input(Bias) of LSTMP operator should not be null."); "LSTMP@Grad");
OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
PADDLE_ENFORCE(ctx->HasInput("BatchGate"), "LSTMP@Grad");
"Input(BatchGate) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTMP operator should not be null.");
auto SetOutGradDim = [&ctx](const std::string& name) { auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name); auto g_name = framework::GradVarName(name);
......
...@@ -2073,7 +2073,21 @@ def dynamic_lstm(input, ...@@ -2073,7 +2073,21 @@ def dynamic_lstm(input,
""" """
assert in_dygraph_mode( assert in_dygraph_mode(
) is not True, "please use lstm instead of dynamic_lstm in dygraph mode!" ) is not True, "please use lstm instead of dynamic_lstm in dygraph mode!"
assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp." assert bias_attr is not False, "bias_attr should not be False in dynamic_lstm."
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'dynamic_lstm')
check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_lstm')
if isinstance(h_0, Variable):
check_variable_and_dtype(h_0, 'h_0', ['float32', 'float64'],
'dynamic_lstm')
check_type(c_0, 'c_0', (Variable, type(None)), 'dynamic_lstm')
if isinstance(c_0, Variable):
check_variable_and_dtype(c_0, 'c_0', ['float32', 'float64'],
'dynamic_lstm')
helper = LayerHelper('lstm', **locals()) helper = LayerHelper('lstm', **locals())
size = size // 4 size = size // 4
weight = helper.create_parameter( weight = helper.create_parameter(
...@@ -2439,6 +2453,20 @@ def dynamic_lstmp(input, ...@@ -2439,6 +2453,20 @@ def dynamic_lstmp(input,
) is not True, "please use lstm instead of dynamic_lstmp in dygraph mode!" ) is not True, "please use lstm instead of dynamic_lstmp in dygraph mode!"
assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp." assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp."
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'dynamic_lstmp')
check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_lstmp')
if isinstance(h_0, Variable):
check_variable_and_dtype(h_0, 'h_0', ['float32', 'float64'],
'dynamic_lstmp')
check_type(c_0, 'c_0', (Variable, type(None)), 'dynamic_lstmp')
if isinstance(c_0, Variable):
check_variable_and_dtype(c_0, 'c_0', ['float32', 'float64'],
'dynamic_lstmp')
helper = LayerHelper('lstmp', **locals()) helper = LayerHelper('lstmp', **locals())
size = size // 4 size = size // 4
weight = helper.create_parameter( weight = helper.create_parameter(
......
...@@ -301,6 +301,42 @@ class TestLstmOpCase3(TestLstmOp): ...@@ -301,6 +301,42 @@ class TestLstmOpCase3(TestLstmOp):
self.lod = [[2, 0, 4]] self.lod = [[2, 0, 4]]
class TestLstmOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
input_data = np.random.random((1, 2048)).astype("float32")
fluid.layers.dynamic_lstm(
input=input_data, size=2048, use_peepholes=False)
self.assertRaises(TypeError, test_Variable)
def test_h_0():
in_data = fluid.data(
name="input", shape=[None, 2048], dtype="float32")
h = fluid.data(name="h", shape=[None, 512], dtype="int32")
c = fluid.data(name="c", shape=[None, 512], dtype="float32")
fluid.layers.dynamic_lstm(
input=in_data, size=2048, use_peepholes=False, h_0=h, c_0=c)
self.assertRaises(TypeError, test_h_0)
def test_c_0():
in_data_ = fluid.data(
name="input_", shape=[None, 2048], dtype="float32")
h_ = fluid.data(name="h_", shape=[None, 512], dtype="float32")
c_ = fluid.data(name="c_", shape=[None, 512], dtype="int32")
fluid.layers.dynamic_lstm(
input=in_data_,
size=2048,
use_peepholes=False,
h_0=h_,
c_0=c_)
self.assertRaises(TypeError, test_c_0)
# class TestLstmOpHasInitial(TestLstmOp): # class TestLstmOpHasInitial(TestLstmOp):
# def set_argument(self): # def set_argument(self):
# self.lod = [[2, 3, 2]] # self.lod = [[2, 3, 2]]
......
...@@ -16,6 +16,8 @@ from __future__ import print_function ...@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import test_lstm_op as LstmTest import test_lstm_op as LstmTest
from paddle import fluid
from paddle.fluid import Program, program_guard
ACTIVATION = { ACTIVATION = {
'identity': LstmTest.identity, 'identity': LstmTest.identity,
...@@ -315,5 +317,59 @@ class TestLstmpOpLen0Case2(TestLstmpOp): ...@@ -315,5 +317,59 @@ class TestLstmpOpLen0Case2(TestLstmpOp):
self.lod = [[2, 0, 3]] self.lod = [[2, 0, 3]]
class TestLstmpOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
input_data = np.random.random((1, 2048)).astype("float32")
fluid.layers.dynamic_lstmp(
input=input_data,
size=2048,
proj_size=256,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh")
self.assertRaises(TypeError, test_Variable)
def test_h_0():
in_data = fluid.data(
name="input", shape=[None, 2048], dtype="float32")
h = fluid.data(name="h", shape=[None, 512], dtype="int32")
c = fluid.data(name="c", shape=[None, 512], dtype="float32")
fluid.layers.dynamic_lstmp(
input=in_data,
size=2048,
proj_size=256,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh",
h_0=h,
c_0=c)
self.assertRaises(TypeError, test_h_0)
def test_c_0():
in_data_ = fluid.data(
name="input_", shape=[None, 2048], dtype="float32")
h_ = fluid.data(name="h_", shape=[None, 512], dtype="float32")
c_ = fluid.data(name="c_", shape=[None, 512], dtype="int32")
fluid.layers.dynamic_lstmp(
input=in_data_,
size=2048,
proj_size=256,
use_peepholes=False,
is_reverse=True,
cell_activation="tanh",
proj_activation="tanh",
h_0=h_,
c_0=c_)
self.assertRaises(TypeError, test_c_0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册