未验证 提交 f11af6a9 编写于 作者: X xiaogang 提交者: GitHub

enhance attention_lstm and param_attr error message (#23678)

* enhance attention_lstm and param_attr error message
* fix: fix param_attr type check
上级 600cb8c8
...@@ -23,97 +23,119 @@ namespace paddle { ...@@ -23,97 +23,119 @@ namespace paddle {
namespace operators { namespace operators {
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AttentionLstm");
"Assert only one Input(X) of AttentionLSTM."); OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "AttentionLstm");
PADDLE_ENFORCE(ctx->HasInput("C0"), OP_INOUT_CHECK(ctx->HasInput("LSTMWeight"), "Input", "LSTMWeight",
"Assert only one Input(C0) of AttentionLSTM."); "AttentionLstm");
PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"), OP_INOUT_CHECK(ctx->HasInput("LSTMBias"), "Input", "LSTMBias",
"Assert only one Input(LSTMWeight) of AttentionLSTM."); "AttentionLstm");
PADDLE_ENFORCE(ctx->HasInput("LSTMBias"), OP_INOUT_CHECK(ctx->HasInput("AttentionWeight"), "Input", "AttentionWeight",
"Assert only one Input(LSTMBias) of AttentionLSTM."); "AttentionLstm");
PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
"Assert only one Input(AttentionWeight) of AttentionLSTM."); OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "AttentionLstm");
OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "AttentionLstm");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), OP_INOUT_CHECK(ctx->HasOutput("AttentionedX"), "Output", "AttentionedX",
"Assert only one Output(Hidden) of AttentionLSTM."); "AttentionLstm");
PADDLE_ENFORCE(ctx->HasOutput("Cell"), OP_INOUT_CHECK(ctx->HasOutput("AttentionFCOut"), "Output", "AttentionFCOut",
"Assert only one Output(Cell) of AttentionLSTM."); "AttentionLstm");
PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"), OP_INOUT_CHECK(ctx->HasOutput("LSTMX"), "Output", "LSTMX", "AttentionLstm");
"Assert only one Output(AttentionedX) of AttentionLSTM."); OP_INOUT_CHECK(ctx->HasOutput("LSTMOUT"), "Output", "LSTMOUT",
PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"), "AttentionLstm");
"Assert only one Output(AttentionFCOut) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
"Assert only one Output(LSTMX) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
"Assert only one Output(LSTMOUT) of AttentionLSTM.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
const int M = x_dims[1]; const int M = x_dims[1];
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
"Input(X)'s rank must be 2."));
auto w_dims = ctx->GetInputDim("LSTMWeight"); auto w_dims = ctx->GetInputDim("LSTMWeight");
const int D = w_dims[1] / 4; const int D = w_dims[1] / 4;
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(w_dims[0], D + M, w_dims.size(), 2,
"LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D); platform::errors::InvalidArgument("Input(LSTMWeight)'s rank must be 2."));
PADDLE_ENFORCE_EQ(
w_dims[0], D + M,
platform::errors::InvalidArgument(
"LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D));
auto b_dims = ctx->GetInputDim("LSTMBias"); auto b_dims = ctx->GetInputDim("LSTMBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); PADDLE_ENFORCE_EQ(b_dims.size(), 2, platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D); "Input(LSTMBias)'s rank must be 2."));
PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D); PADDLE_ENFORCE_EQ(b_dims[0], 1,
platform::errors::InvalidArgument(
"LSTMBias dims should be 1 x %d.", 4 * D));
PADDLE_ENFORCE_EQ(b_dims[1], 4 * D,
platform::errors::InvalidArgument(
"LSTMBias dims should be 1 x %d.", 4 * D));
auto c_dims = ctx->GetInputDim("C0"); auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); PADDLE_ENFORCE_EQ(c_dims.size(), 2, platform::errors::InvalidArgument(
"Input(C0)'s rank must be 2."));
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); PADDLE_ENFORCE_EQ(c_dims[1], D, platform::errors::InvalidArgument(
"C0 dims should be N x %d.", D));
} }
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
auto h_dims = ctx->GetInputDim("H0"); auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2."); PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, platform::errors::InvalidArgument(
"Input(H0)'s rank must be 2."));
if (ctx->IsRuntime() || if (ctx->IsRuntime() ||
(framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) { (framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
PADDLE_ENFORCE(h_dims == c_dims, PADDLE_ENFORCE_EQ(h_dims, c_dims,
"The dimension of Input(H0) and Input(C0) " platform::errors::InvalidArgument(
"should be the same."); "The dimension of Input(H0) and Input(C0) "
"should be the same."));
} }
} }
auto atten_w_dims = ctx->GetInputDim("AttentionWeight"); auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2, PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
"Input(AttentionWeight)'s rank must be 2."); platform::errors::InvalidArgument(
"Input(AttentionWeight)'s rank must be 2."));
PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D, PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D); platform::errors::InvalidArgument(
"AttentionWeight shapes must be (%d + %d) * 1.", M, D));
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D); platform::errors::InvalidArgument(
"AttentionWeight shapes must be (%d + %d) * 1.", M, D));
if (ctx->HasInput("AttentionBias")) { if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias"); auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
"Input(AttentionBias)'s rank must be 2."); platform::errors::InvalidArgument(
"Input(AttentionBias)'s rank must be 2."));
PADDLE_ENFORCE_EQ(atten_b_dims[0], 1, PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
"AttentionBias shapes must be 1 * 1."); platform::errors::InvalidArgument(
"AttentionBias shapes must be 1 * 1."));
PADDLE_ENFORCE_EQ(atten_b_dims[1], 1, PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
"AttentionBias shapes must be 1 * 1."); platform::errors::InvalidArgument(
"AttentionBias shapes must be 1 * 1."));
} }
if (ctx->HasInput("AttentionScalar")) { if (ctx->HasInput("AttentionScalar")) {
auto dims = ctx->GetInputDim("AttentionScalar"); auto dims = ctx->GetInputDim("AttentionScalar");
PADDLE_ENFORCE_EQ(dims.size(), 2, PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalar)'s rank must be 2."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1."); "Input(AttentionScalar)'s rank must be 2."));
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1."); PADDLE_ENFORCE_EQ(dims[0], 1, platform::errors::InvalidArgument(
"AttentionScalar shapes must be 1 * 1."));
PADDLE_ENFORCE_EQ(dims[1], 1, platform::errors::InvalidArgument(
"AttentionScalar shapes must be 1 * 1."));
} }
if (ctx->HasInput("AttentionScalarBias")) { if (ctx->HasInput("AttentionScalarBias")) {
auto dims = ctx->GetInputDim("AttentionScalarBias"); auto dims = ctx->GetInputDim("AttentionScalarBias");
PADDLE_ENFORCE( OP_INOUT_CHECK(ctx->HasInput("AttentionScalar"), "Input", "AttentionScalar",
ctx->HasInput("AttentionScalar"), "AttentionLstm");
"AttentionScalar should not be null when have AttentionScalarBias.");
PADDLE_ENFORCE_EQ(dims.size(), 2, PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalarBias)'s rank must be 2."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1."); "Input(AttentionScalarBias)'s rank must be 2."));
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1."); PADDLE_ENFORCE_EQ(dims[0], 1,
platform::errors::InvalidArgument(
"AttentionScalarBias shapes must be 1 * 1."));
PADDLE_ENFORCE_EQ(dims[1], 1,
platform::errors::InvalidArgument(
"AttentionScalarBias shapes must be 1 * 1."));
} }
framework::DDim out_dims({x_dims[0], D}); framework::DDim out_dims({x_dims[0], D});
...@@ -301,8 +323,11 @@ class AttentionLSTMKernel : public framework::OpKernel<T> { ...@@ -301,8 +323,11 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
int len = x_lod[0][i + 1] - x_lod[0][i]; int len = x_lod[0][i + 1] - x_lod[0][i];
max_seq_len = max_seq_len < len ? len : max_seq_len; max_seq_len = max_seq_len < len ? len : max_seq_len;
} }
PADDLE_ENFORCE_EQ(x_lod.size(), 1UL, "Input(X)'s lod size must be 1."); PADDLE_ENFORCE_EQ(x_lod.size(), 1UL, platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D); "Input(X)'s lod size must be 1."));
PADDLE_ENFORCE_EQ(
c0->dims()[0], N,
platform::errors::InvalidArgument("C0 dims should be %d x %d.", N, D));
fc_out->Resize({max_seq_len, 1}); fc_out->Resize({max_seq_len, 1});
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
......
...@@ -16,9 +16,11 @@ from __future__ import print_function ...@@ -16,9 +16,11 @@ from __future__ import print_function
import six import six
import warnings import warnings
import sys
from .initializer import Initializer, Xavier, Constant from .initializer import Initializer, Xavier, Constant
from .regularizer import WeightDecayRegularizer from .regularizer import WeightDecayRegularizer
from paddle.fluid.data_feeder import check_type
__all__ = [ __all__ = [
'ParamAttr', 'ParamAttr',
...@@ -77,8 +79,17 @@ class ParamAttr(object): ...@@ -77,8 +79,17 @@ class ParamAttr(object):
regularizer=None, regularizer=None,
trainable=True, trainable=True,
do_model_average=True): do_model_average=True):
if sys.version_info.major == 2:
check_type(name, "name", (str, type(None), unicode), "ParamAttr")
else:
check_type(name, "name", (str, type(None)), "ParamAttr")
check_type(learning_rate, "learning_rate", (float, int), "ParamAttr")
check_type(trainable, "trainable", (bool), "ParamAttr")
check_type(do_model_average, "do_model_average", (bool), "ParamAttr")
self.name = name self.name = name
if isinstance(self.name, six.string_types) and self.name == "": if self.name == "":
raise ValueError("name of ParamAttr can not be empty str") raise ValueError("name of ParamAttr can not be empty str")
self.initializer = initializer self.initializer = initializer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册