未验证 提交 fff9faae 编写于 作者: L liu zhengxi 提交者: GitHub

API(dynamic_gru, chunk_eval, BeamSearchDecoder) error message enhancement (#24513)

* dynamic_gru err_msg enhancement, test=develop

* chunk_eval err_msg enhancement and fix crf_decoding output type, test=develop

* BeamSearchDecoder err msg enhancement, test=develop

* fix doc for chunk_eval, test=develop

* refine lod err msg for chunk_eval, test=develop
上级 5ff45357
...@@ -24,45 +24,48 @@ class ChunkEvalOp : public framework::OperatorWithKernel { ...@@ -24,45 +24,48 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Inference"), true, OP_INOUT_CHECK(ctx->HasInput("Inference"), "Input", "Inference",
"Input(Inference) of ChunkEvalOp should not be null."); "chunk_eval");
PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true, OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "chunk_eval");
"Input(Label) of ChunkEvalOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Precision"), true, OP_INOUT_CHECK(ctx->HasOutput("Precision"), "Output", "Precision",
"Output(Precision) of ChunkEvalOp should not be null."); "chunk_eval");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Recall"), true, OP_INOUT_CHECK(ctx->HasOutput("Recall"), "Output", "Recall", "chunk_eval");
"Output(Recall) of ChunkEvalOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("F1-Score"), "Output", "F1-Score",
PADDLE_ENFORCE_EQ(ctx->HasOutput("F1-Score"), true, "chunk_eval");
"Output(F1-Score) of ChunkEvalOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("NumInferChunks"), "Output", "NumInferChunks",
PADDLE_ENFORCE_EQ( "chunk_eval");
ctx->HasOutput("NumInferChunks"), true, OP_INOUT_CHECK(ctx->HasOutput("NumLabelChunks"), "Output", "NumLabelChunks",
"Output(NumInferChunks) of ChunkEvalOp should not be null."); "chunk_eval");
PADDLE_ENFORCE_EQ( OP_INOUT_CHECK(ctx->HasOutput("NumCorrectChunks"), "Output",
ctx->HasOutput("NumLabelChunks"), true, "NumCorrectChunks", "chunk_eval");
"Output(NumLabelChunks) of ChunkEvalOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasOutput("NumCorrectChunks"), true,
"Output(NumCorrectChunks) of ChunkEvalOp should not be null.");
auto inference_dim = ctx->GetInputDim("Inference"); auto inference_dim = ctx->GetInputDim("Inference");
auto label_dim = ctx->GetInputDim("Label"); auto label_dim = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
inference_dim, label_dim, inference_dim, label_dim,
"Input(Inference)'s shape must be the same as Input(Label)'s shape."); platform::errors::InvalidArgument(
"Input(Inference)'s shape must be the same as Input(Label)'s "
"shape, but received [%s] (Inference) vs [%s] (Label).",
inference_dim, label_dim));
bool use_padding = ctx->HasInput("SeqLength"); bool use_padding = ctx->HasInput("SeqLength");
if (use_padding) { if (use_padding) {
PADDLE_ENFORCE_EQ((inference_dim.size() == 3 && inference_dim[2] == 1) || PADDLE_ENFORCE_EQ(
inference_dim.size() == 2, (inference_dim.size() == 3 && inference_dim[2] == 1) ||
true, inference_dim.size() == 2,
"when Input(SeqLength) is provided, Input(Inference) " true, platform::errors::InvalidArgument(
"should be of dim 3 (batch_size, bucket, 1) or dim 2 " "when Input(SeqLength) is provided, Input(Inference) "
"(batch_size, bucket)."); "should be of dim 3 (batch_size, bucket, 1) or dim 2 "
"(batch_size, bucket), but received [%s].",
inference_dim));
auto seq_length_dim = ctx->GetInputDim("SeqLength"); auto seq_length_dim = ctx->GetInputDim("SeqLength");
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(seq_length_dim.size(), 2,
seq_length_dim.size(), 2, platform::errors::InvalidArgument(
"Input(SeqLength)'s rank should not be greater than 2."); "Input(SeqLength)'s rank should not be greater "
"than 2, but received %d.",
seq_length_dim.size()));
} }
ctx->SetOutputDim("Precision", {1}); ctx->SetOutputDim("Precision", {1});
......
...@@ -51,7 +51,13 @@ class ChunkEvalKernel : public framework::OpKernel<T> { ...@@ -51,7 +51,13 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
for (int i = 0; i < length; ++i) { for (int i = 0; i < length; ++i) {
int prev_tag = tag; int prev_tag = tag;
int prev_type = type; int prev_type = type;
PADDLE_ENFORCE_LE(label[i], num_chunk_types * num_tag_types); PADDLE_ENFORCE_LE(
label[i], num_chunk_types * num_tag_types,
platform::errors::InvalidArgument(
"The value of Input(Label) should be less than the number of "
"chunk types times the number of tag types, but received %d "
"(Label) vs %d (chunk types) * %d (tag types).",
label[i], num_chunk_types, num_tag_types));
tag = label[i] % num_tag_types; tag = label[i] % num_tag_types;
type = label[i] / num_tag_types; type = label[i] / num_tag_types;
if (in_chunk && ChunkEnd(prev_tag, prev_type, tag, type, other_chunk_type, if (in_chunk && ChunkEnd(prev_tag, prev_type, tag, type, other_chunk_type,
...@@ -191,10 +197,16 @@ class ChunkEvalKernel : public framework::OpKernel<T> { ...@@ -191,10 +197,16 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
tag_inside, tag_end, tag_single, excluded_chunk_types); tag_inside, tag_end, tag_single, excluded_chunk_types);
} }
} else { } else {
PADDLE_ENFORCE_EQ(lod.size(), 1UL, PADDLE_ENFORCE_EQ(
"Only support one level sequence now."); lod.size(), 1UL,
PADDLE_ENFORCE(lod == inference->lod(), platform::errors::InvalidArgument(
"LoD must be same between Inference and Label."); "Only support one level LoD sequence now, but received %d.",
lod.size()));
PADDLE_ENFORCE_EQ(
lod, inference->lod(),
platform::errors::InvalidArgument(
"Input(Inference) and Input(Label) of Op(chunk_eval) should have "
"same LoD information."));
num_sequences = lod[0].size() - 1; num_sequences = lod[0].size() - 1;
for (int i = 0; i < num_sequences; ++i) { for (int i = 0; i < num_sequences; ++i) {
......
...@@ -31,44 +31,58 @@ class GRUOp : public framework::OperatorWithKernel { ...@@ -31,44 +31,58 @@ class GRUOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU");
"Input(%s) of GRUOp should not be null.", "Input"); OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU");
PADDLE_ENFORCE(ctx->HasInput("Weight"), OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU");
"Input(%s) of GRUOp should not be null.", "Weight"); OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output",
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"), "BatchResetHiddenPrev", "GRU");
"Output(%s) of GRUOp should not be null.", "BatchGate"); OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"), "GRU");
"Output(%s) of GRUOp should not be null.", OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU");
"BatchResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
"Output(%s) of GRUOp should not be null.", "BatchHidden");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(%s) of GRUOp should not be null.", "Hidden");
auto input_dims = ctx->GetInputDim("Input"); auto input_dims = ctx->GetInputDim("Input");
auto weight_dims = ctx->GetInputDim("Weight"); auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1]; int input_size = input_dims[1];
int frame_size = weight_dims[0]; int frame_size = weight_dims[0];
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
input_size, frame_size * 3, platform::errors::InvalidArgument(
"The input_size must be 3 times of frame_size in GRUOp."); "The second dimension of Input(Input) must be 3 "
"times of frame_size in GRUOp, but received %d "
"(Input) vs %d (frame_size).",
input_size, frame_size));
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_dims[1], frame_size * 3, weight_dims[1], frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).",
weight_dims[0], weight_dims[1], frame_size, frame_size * 3));
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0"); auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size, PADDLE_ENFORCE_EQ(
"The width of H0 must be equal to frame_size."); h0_dims[1], frame_size,
platform::errors::InvalidArgument(
"The width of Input(H0) must be equal to frame_size, but "
"received %d (width of H0) vs %d (frame_size).",
h0_dims[1], frame_size));
} }
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias"); auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0]; int bias_height = bias_dims[0];
int bias_width = bias_dims[1]; int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1, PADDLE_ENFORCE_EQ(
"The shape of Bias must be [1, frame_size * 3]."); bias_height, 1,
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3]."); "The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
PADDLE_ENFORCE_EQ(
bias_width, frame_size * 3,
platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
} }
ctx->SetOutputDim("BatchGate", input_dims); ctx->SetOutputDim("BatchGate", input_dims);
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size}); ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
...@@ -166,39 +180,50 @@ class GRUGradOp : public framework::OperatorWithKernel { ...@@ -166,39 +180,50 @@ class GRUGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU@Grad");
"Input(%s) of GRUGradOp should not be null.", "Input"); OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU@Grad");
PADDLE_ENFORCE(ctx->HasInput("Weight"), OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
"Input(%s) of GRUGradOp should not be null.", "Weight"); "GRU@Grad");
PADDLE_ENFORCE(ctx->HasInput("BatchGate"), OP_INOUT_CHECK(ctx->HasInput("BatchResetHiddenPrev"), "Input",
"Input(%s) of GRUGradOp should not be null.", "BatchGate"); "BatchResetHiddenPrev", "GRU@Grad");
PADDLE_ENFORCE(ctx->HasInput("BatchResetHiddenPrev"), OP_INOUT_CHECK(ctx->HasInput("BatchHidden"), "Input", "BatchHidden",
"Input(%s) of GRUGradOp should not be null.", "GRU@Grad");
"BatchResetHiddenPrev"); OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "GRU@Grad");
PADDLE_ENFORCE(ctx->HasInput("BatchHidden"), OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Hidden")), "Input",
"Input(%s) of GRUOp should not be null.", "BatchHidden"); framework::GradVarName("Hidden"), "GRU@Grad");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(%s) of GRUGradOp should not be null.", "Hidden");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUGradOp should not be null.", "Hidden");
auto input_dims = ctx->GetInputDim("Input"); auto input_dims = ctx->GetInputDim("Input");
auto weight_dims = ctx->GetInputDim("Weight"); auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1]; int input_size = input_dims[1];
int frame_size = weight_dims[0]; int frame_size = weight_dims[0];
int weight_height = weight_dims[0]; int weight_height = weight_dims[0];
int weight_width = weight_dims[1]; int weight_width = weight_dims[1];
PADDLE_ENFORCE_EQ(input_size, frame_size * 3, PADDLE_ENFORCE_EQ(
"The input_size must be 3 times of frame_size in GRUOp."); input_size, frame_size * 3,
platform::errors::InvalidArgument(
"The second dimension of Input(Input) must be 3 times of "
"frame_size in GRUOp, but received %d (Input) vs %d (frame_size).",
input_size, frame_size));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_height, frame_size, weight_height, frame_size,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).",
weight_height, weight_width, frame_size, frame_size * 3));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3, weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); platform::errors::InvalidArgument(
"The shape of Input(Weight) matrix must be [frame_size, frame_size "
"* 3], but received [%d, %d] (Weight) vs [%d, %d] (frame_size).",
weight_height, weight_width, frame_size, frame_size * 3));
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0"); auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size, PADDLE_ENFORCE_EQ(
"The width of H0 must be equal to frame_size."); h0_dims[1], frame_size,
platform::errors::InvalidArgument(
"The width of Input(H0) must be equal to frame_size, but "
"received %d (width of H0) vs %d (frame_size).",
h0_dims[1], frame_size));
auto h0_grad_name = framework::GradVarName("H0"); auto h0_grad_name = framework::GradVarName("H0");
if (ctx->HasOutput(h0_grad_name)) if (ctx->HasOutput(h0_grad_name))
ctx->SetOutputDim(h0_grad_name, h0_dims); ctx->SetOutputDim(h0_grad_name, h0_dims);
...@@ -207,10 +232,18 @@ class GRUGradOp : public framework::OperatorWithKernel { ...@@ -207,10 +232,18 @@ class GRUGradOp : public framework::OperatorWithKernel {
auto bias_dims = ctx->GetInputDim("Bias"); auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0]; int bias_height = bias_dims[0];
int bias_width = bias_dims[1]; int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1, PADDLE_ENFORCE_EQ(
"The shape of Bias must be [1, frame_size * 3]."); bias_height, 1,
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3]."); "The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
PADDLE_ENFORCE_EQ(
bias_width, frame_size * 3,
platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3], but received "
"[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
bias_height, bias_width, frame_size * 3));
auto bias_grad_name = framework::GradVarName("Bias"); auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name)) if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims); ctx->SetOutputDim(bias_grad_name, bias_dims);
...@@ -298,14 +331,20 @@ class GRUCPUKernel : public framework::OpKernel<T> { ...@@ -298,14 +331,20 @@ class GRUCPUKernel : public framework::OpKernel<T> {
T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size * 2 /*width of weight*/, frame_size * 2 /*width of weight*/,
frame_size /*height of height*/); frame_size /*height of height*/);
PADDLE_ENFORCE(packed_gate); PADDLE_ENFORCE_NOT_NULL(
packed_gate, platform::errors::NotFound(
"The caculation result of packed_gate by "
"GEMM_ALLOC should not be null when using MKL."));
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2, blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
frame_size, T(1.0), gru_value.gate_weight, frame_size * 2, frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
packed_gate); packed_gate);
T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/, T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size /*width of weight*/, frame_size /*width of weight*/,
frame_size /*height of height*/); frame_size /*height of height*/);
PADDLE_ENFORCE(packed_state); PADDLE_ENFORCE_NOT_NULL(
packed_state, platform::errors::NotFound(
"The caculation result of packed_state by "
"GEMM_ALLOC should not be null when using MKL."));
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size, blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
frame_size, T(1.0), gru_value.state_weight, frame_size, frame_size, T(1.0), gru_value.state_weight, frame_size,
packed_state); packed_state);
......
...@@ -219,7 +219,13 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -219,7 +219,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
auto in_dims = input->dims(); auto in_dims = input->dims();
auto out_dims = hidden_g->dims(); auto out_dims = hidden_g->dims();
int frame_size = static_cast<int>(in_dims[1] / 4); int frame_size = static_cast<int>(in_dims[1] / 4);
PADDLE_ENFORCE_EQ(frame_size, out_dims[1]); PADDLE_ENFORCE_EQ(
frame_size, out_dims[1],
platform::errors::InvalidArgument(
"The second dimension of Input(" +
framework::GradVarName("Hidden") +
") should be %d, but received %d in LSTM@Grad operator.",
frame_size, out_dims[1]));
math::LstmMetaValue<T> lstm_value; math::LstmMetaValue<T> lstm_value;
if (bias && ctx.Attr<bool>("use_peepholes")) { if (bias && ctx.Attr<bool>("use_peepholes")) {
......
...@@ -327,7 +327,11 @@ class LSTMPGradKernel : public framework::OpKernel<T> { ...@@ -327,7 +327,11 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
auto out_dims = cell_out->dims(); auto out_dims = cell_out->dims();
framework::DDim proj_dims({in_dims[0], proj_weight->dims()[1]}); framework::DDim proj_dims({in_dims[0], proj_weight->dims()[1]});
int frame_size = static_cast<int>(in_dims[1] / 4); int frame_size = static_cast<int>(in_dims[1] / 4);
PADDLE_ENFORCE_EQ(frame_size, out_dims[1]); PADDLE_ENFORCE_EQ(frame_size, out_dims[1],
platform::errors::InvalidArgument(
"The second dimension of Input(Cell) should be %d, "
"but received %d in LSTMP@Grad operator.",
frame_size, out_dims[1]));
math::LstmMetaValue<T> lstmp_value; math::LstmMetaValue<T> lstmp_value;
if (bias && ctx.Attr<bool>("use_peepholes")) { if (bias && ctx.Attr<bool>("use_peepholes")) {
......
...@@ -875,7 +875,7 @@ def crf_decoding(input, param_attr, label=None, length=None): ...@@ -875,7 +875,7 @@ def crf_decoding(input, param_attr, label=None, length=None):
helper = LayerHelper('crf_decoding', **locals()) helper = LayerHelper('crf_decoding', **locals())
transition = helper.get_parameter(param_attr.name) transition = helper.get_parameter(param_attr.name)
viterbi_path = helper.create_variable_for_type_inference( viterbi_path = helper.create_variable_for_type_inference(
dtype=helper.input_dtype()) dtype=core.VarDesc.VarType.INT64)
inputs = {"Emission": [input], "Transition": transition, "Label": label} inputs = {"Emission": [input], "Transition": transition, "Label": label}
if length: if length:
inputs['Length'] = length inputs['Length'] = length
...@@ -1125,12 +1125,12 @@ def chunk_eval(input, ...@@ -1125,12 +1125,12 @@ def chunk_eval(input,
dict_size = 10000 dict_size = 10000
label_dict_len = 7 label_dict_len = 7
sequence = fluid.data( sequence = fluid.data(
name='id', shape=[-1, 1], lod_level=1, dtype='int64') name='id', shape=[None, 1], lod_level=1, dtype='int64')
embedding = fluid.embedding( embedding = fluid.embedding(
input=sequence, size=[dict_size, 512]) input=sequence, size=[dict_size, 512])
hidden = fluid.layers.fc(input=embedding, size=512) hidden = fluid.layers.fc(input=embedding, size=512)
label = fluid.layers.data( label = fluid.data(
name='label', shape=[1], lod_level=1, dtype='int32') name='label', shape=[None, 1], lod_level=1, dtype='int64')
crf = fluid.layers.linear_chain_crf( crf = fluid.layers.linear_chain_crf(
input=hidden, label=label, param_attr=fluid.ParamAttr(name="crfw")) input=hidden, label=label, param_attr=fluid.ParamAttr(name="crfw"))
crf_decode = fluid.layers.crf_decoding( crf_decode = fluid.layers.crf_decoding(
...@@ -1139,10 +1139,13 @@ def chunk_eval(input, ...@@ -1139,10 +1139,13 @@ def chunk_eval(input,
input=crf_decode, input=crf_decode,
label=label, label=label,
chunk_scheme="IOB", chunk_scheme="IOB",
num_chunk_types=(label_dict_len - 1) / 2) num_chunk_types=int((label_dict_len - 1) / 2))
""" """
helper = LayerHelper("chunk_eval", **locals()) helper = LayerHelper("chunk_eval", **locals())
check_variable_and_dtype(input, 'input', ['int64'], 'chunk_eval')
check_variable_and_dtype(label, 'label', ['int64'], 'chunk_eval')
# prepare output # prepare output
precision = helper.create_variable_for_type_inference(dtype="float32") precision = helper.create_variable_for_type_inference(dtype="float32")
recall = helper.create_variable_for_type_inference(dtype="float32") recall = helper.create_variable_for_type_inference(dtype="float32")
......
...@@ -790,6 +790,8 @@ class BeamSearchDecoder(Decoder): ...@@ -790,6 +790,8 @@ class BeamSearchDecoder(Decoder):
Variable: A tensor with shape `[batch_size * beam_size, ...]`, whose \ Variable: A tensor with shape `[batch_size * beam_size, ...]`, whose \
data type is same as `x`. data type is same as `x`.
""" """
check_type(x, 'x', (Variable),
'BeamSearchDecoder.tile_beam_merge_with_batch')
x = nn.unsqueeze(x, [1]) # [batch_size, 1, ...] x = nn.unsqueeze(x, [1]) # [batch_size, 1, ...]
expand_times = [1] * len(x.shape) expand_times = [1] * len(x.shape)
expand_times[1] = beam_size expand_times[1] = beam_size
...@@ -818,6 +820,7 @@ class BeamSearchDecoder(Decoder): ...@@ -818,6 +820,7 @@ class BeamSearchDecoder(Decoder):
Variable: A tensor with shape `[batch_size, beam_size, ...]`, whose \ Variable: A tensor with shape `[batch_size, beam_size, ...]`, whose \
data type is same as `x`. data type is same as `x`.
""" """
check_type(x, 'x', (Variable), 'BeamSearchDecoder._split_batch_beams')
# TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch # TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch
return nn.reshape(x, shape=[-1, self.beam_size] + list(x.shape[1:])) return nn.reshape(x, shape=[-1, self.beam_size] + list(x.shape[1:]))
...@@ -834,6 +837,7 @@ class BeamSearchDecoder(Decoder): ...@@ -834,6 +837,7 @@ class BeamSearchDecoder(Decoder):
Variable: A tensor with shape `[batch_size * beam_size, ...]`, whose \ Variable: A tensor with shape `[batch_size * beam_size, ...]`, whose \
data type is same as `x`. data type is same as `x`.
""" """
check_type(x, 'x', (Variable), 'BeamSearchDecoder._merge_batch_beams')
# TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch # TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch
return nn.reshape(x, shape=[-1] + list(x.shape[2:])) return nn.reshape(x, shape=[-1] + list(x.shape[2:]))
...@@ -846,16 +850,14 @@ class BeamSearchDecoder(Decoder): ...@@ -846,16 +850,14 @@ class BeamSearchDecoder(Decoder):
`beam_size` times. `beam_size` times.
Parameters: Parameters:
probs(Variable): A tensor with shape `[batch_size, ...]`, representing x(Variable): A tensor with shape `[batch_size, ...]`, The data type
the log probabilities. Its data type should be float32 or float64. should be float32, float64, int32, int64 or bool.
finished(Variable): A tensor with shape `[batch_size, beam_size]`,
representing the finished status for all beams. Its data type
should be bool.
Returns: Returns:
Variable: A tensor with shape `[batch_size, beam_size, ...]`, whose \ Variable: A tensor with shape `[batch_size, beam_size, ...]`, whose \
data type is same as `x`. data type is same as `x`.
""" """
check_type(x, 'x', (Variable), 'BeamSearchDecoder._expand_to_beam_size')
x = nn.unsqueeze(x, [1]) x = nn.unsqueeze(x, [1])
expand_times = [1] * len(x.shape) expand_times = [1] * len(x.shape)
expand_times[1] = self.beam_size expand_times[1] = self.beam_size
...@@ -879,6 +881,9 @@ class BeamSearchDecoder(Decoder): ...@@ -879,6 +881,9 @@ class BeamSearchDecoder(Decoder):
where unfinished beams stay unchanged and finished beams are \ where unfinished beams stay unchanged and finished beams are \
replaced with a tensor with all probability on the EOS token. replaced with a tensor with all probability on the EOS token.
""" """
check_type(probs, 'probs', (Variable), 'BeamSearchDecoder._mask_probs')
check_type(finished, 'finished', (Variable),
'BeamSearchDecoder._mask_probs')
# TODO: use where_op # TODO: use where_op
finished = tensor.cast(finished, dtype=probs.dtype) finished = tensor.cast(finished, dtype=probs.dtype)
probs = nn.elementwise_mul( probs = nn.elementwise_mul(
...@@ -903,6 +908,10 @@ class BeamSearchDecoder(Decoder): ...@@ -903,6 +908,10 @@ class BeamSearchDecoder(Decoder):
Variable: A tensor with the same shape and data type as `x`, \ Variable: A tensor with the same shape and data type as `x`, \
representing the gathered tensor. representing the gathered tensor.
""" """
check_type(x, 'x', (Variable), 'BeamSearchDecoder._gather')
check_type(indices, 'indices', (Variable), 'BeamSearchDecoder._gather')
check_type(batch_size, 'batch_size', (Variable),
'BeamSearchDecoder._gather')
# TODO: compatibility of int32 and int64 # TODO: compatibility of int32 and int64
batch_size = tensor.cast( batch_size = tensor.cast(
batch_size, batch_size,
...@@ -2666,6 +2675,14 @@ def dynamic_gru(input, ...@@ -2666,6 +2675,14 @@ def dynamic_gru(input,
assert in_dygraph_mode( assert in_dygraph_mode(
) is not True, "please use gru instead of dynamic_gru in dygraph mode!" ) is not True, "please use gru instead of dynamic_gru in dygraph mode!"
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'dynamic_gru')
check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_gru')
if isinstance(h_0, Variable):
check_variable_and_dtype(h_0, 'h_0', ['float32', 'float64'],
'dynamic_gru')
helper = LayerHelper('gru', **locals()) helper = LayerHelper('gru', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
......
...@@ -17,6 +17,9 @@ from __future__ import print_function ...@@ -17,6 +17,9 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import numpy as np
from paddle.fluid import Program, program_guard
from paddle import fluid
class Segment(object): class Segment(object):
...@@ -229,5 +232,45 @@ class TestChunkEvalOpWithTensorInput(TestChunkEvalOp): ...@@ -229,5 +232,45 @@ class TestChunkEvalOpWithTensorInput(TestChunkEvalOp):
} }
class TestChunkEvalOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_input():
input_data = np.random.random(1, 1).astype("int64")
label_data = np.random.random(1).astype("int64")
fluid.layers.chunk_eval(
input=input_data,
label=label_data,
chunk_scheme="IOB",
num_chunk_types=3)
self.assertRaises(TypeError, test_input)
def test_label():
input_ = fluid.data(
name="input", shape=[None, 1], dtype="int64")
label_data = np.random.random(1).astype("int64")
fluid.layers.chunk_eval(
input=input_,
label=label_data,
chunk_scheme="IOB",
num_chunk_types=3)
self.assertRaises(TypeError, test_label)
def test_type():
in_data = fluid.data(
name="input_", shape=[None, 1], dtype="int32")
label = fluid.data(name="label_", shape=[1], dtype="int64")
fluid.layers.chunk_eval(
input=in_data,
label=label,
chunk_scheme="IOB",
num_chunk_types=3)
self.assertRaises(TypeError, test_type)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -20,6 +20,8 @@ import math ...@@ -20,6 +20,8 @@ import math
import functools import functools
from op_test import OpTest from op_test import OpTest
from test_lstm_op import ACTIVATION from test_lstm_op import ACTIVATION
from paddle import fluid
from paddle.fluid import Program, program_guard
def gru( def gru(
...@@ -227,5 +229,24 @@ class TestGRUOpReverseOriginMode(TestGRUOp): ...@@ -227,5 +229,24 @@ class TestGRUOpReverseOriginMode(TestGRUOp):
self.origin_mode = True self.origin_mode = True
class TestGruOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
input_data = np.random.random((1, 1536)).astype("float32")
fluid.layers.dynamic_gru(input=input_data, size=512)
self.assertRaises(TypeError, test_Variable)
def test_h_0():
in_data = fluid.data(
name="input", shape=[None, 1536], dtype="float32")
h = fluid.data(name="h", shape=[None, 512], dtype="int32")
fluid.layers.dynamic_gru(input=in_data, size=512, h_0=h)
self.assertRaises(TypeError, test_h_0)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册