未验证 提交 f11af6a9 编写于 作者: X xiaogang 提交者: GitHub

enhance attention_lstm and param_attr error message (#23678)

* enhance attention_lstm and param_attr error message
* fix: fix param_attr type check
上级 600cb8c8
......@@ -23,97 +23,119 @@ namespace paddle {
namespace operators {
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Assert only one Input(X) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Assert only one Input(C0) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
"Assert only one Input(LSTMWeight) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
"Assert only one Input(LSTMBias) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
"Assert only one Input(AttentionWeight) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Assert only one Output(Hidden) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Assert only one Output(Cell) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
"Assert only one Output(AttentionedX) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
"Assert only one Output(AttentionFCOut) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
"Assert only one Output(LSTMX) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
"Assert only one Output(LSTMOUT) of AttentionLSTM.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AttentionLstm");
OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "AttentionLstm");
OP_INOUT_CHECK(ctx->HasInput("LSTMWeight"), "Input", "LSTMWeight",
"AttentionLstm");
OP_INOUT_CHECK(ctx->HasInput("LSTMBias"), "Input", "LSTMBias",
"AttentionLstm");
OP_INOUT_CHECK(ctx->HasInput("AttentionWeight"), "Input", "AttentionWeight",
"AttentionLstm");
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "AttentionLstm");
OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "AttentionLstm");
OP_INOUT_CHECK(ctx->HasOutput("AttentionedX"), "Output", "AttentionedX",
"AttentionLstm");
OP_INOUT_CHECK(ctx->HasOutput("AttentionFCOut"), "Output", "AttentionFCOut",
"AttentionLstm");
OP_INOUT_CHECK(ctx->HasOutput("LSTMX"), "Output", "LSTMX", "AttentionLstm");
OP_INOUT_CHECK(ctx->HasOutput("LSTMOUT"), "Output", "LSTMOUT",
"AttentionLstm");
auto x_dims = ctx->GetInputDim("X");
const int M = x_dims[1];
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
"Input(X)'s rank must be 2."));
auto w_dims = ctx->GetInputDim("LSTMWeight");
const int D = w_dims[1] / 4;
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], D + M,
"LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
PADDLE_ENFORCE_EQ(
w_dims.size(), 2,
platform::errors::InvalidArgument("Input(LSTMWeight)'s rank must be 2."));
PADDLE_ENFORCE_EQ(
w_dims[0], D + M,
platform::errors::InvalidArgument(
"LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D));
auto b_dims = ctx->GetInputDim("LSTMBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
PADDLE_ENFORCE_EQ(b_dims.size(), 2, platform::errors::InvalidArgument(
"Input(LSTMBias)'s rank must be 2."));
PADDLE_ENFORCE_EQ(b_dims[0], 1,
platform::errors::InvalidArgument(
"LSTMBias dims should be 1 x %d.", 4 * D));
PADDLE_ENFORCE_EQ(b_dims[1], 4 * D,
platform::errors::InvalidArgument(
"LSTMBias dims should be 1 x %d.", 4 * D));
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, platform::errors::InvalidArgument(
"Input(C0)'s rank must be 2."));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
PADDLE_ENFORCE_EQ(c_dims[1], D, platform::errors::InvalidArgument(
"C0 dims should be N x %d.", D));
}
if (ctx->HasInput("H0")) {
auto h_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, platform::errors::InvalidArgument(
"Input(H0)'s rank must be 2."));
if (ctx->IsRuntime() ||
(framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
PADDLE_ENFORCE(h_dims == c_dims,
PADDLE_ENFORCE_EQ(h_dims, c_dims,
platform::errors::InvalidArgument(
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
"should be the same."));
}
}
auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
"Input(AttentionWeight)'s rank must be 2.");
platform::errors::InvalidArgument(
"Input(AttentionWeight)'s rank must be 2."));
PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
platform::errors::InvalidArgument(
"AttentionWeight shapes must be (%d + %d) * 1.", M, D));
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
platform::errors::InvalidArgument(
"AttentionWeight shapes must be (%d + %d) * 1.", M, D));
if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
"Input(AttentionBias)'s rank must be 2.");
platform::errors::InvalidArgument(
"Input(AttentionBias)'s rank must be 2."));
PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
"AttentionBias shapes must be 1 * 1.");
platform::errors::InvalidArgument(
"AttentionBias shapes must be 1 * 1."));
PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
"AttentionBias shapes must be 1 * 1.");
platform::errors::InvalidArgument(
"AttentionBias shapes must be 1 * 1."));
}
if (ctx->HasInput("AttentionScalar")) {
auto dims = ctx->GetInputDim("AttentionScalar");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalar)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
platform::errors::InvalidArgument(
"Input(AttentionScalar)'s rank must be 2."));
PADDLE_ENFORCE_EQ(dims[0], 1, platform::errors::InvalidArgument(
"AttentionScalar shapes must be 1 * 1."));
PADDLE_ENFORCE_EQ(dims[1], 1, platform::errors::InvalidArgument(
"AttentionScalar shapes must be 1 * 1."));
}
if (ctx->HasInput("AttentionScalarBias")) {
auto dims = ctx->GetInputDim("AttentionScalarBias");
PADDLE_ENFORCE(
ctx->HasInput("AttentionScalar"),
"AttentionScalar should not be null when have AttentionScalarBias.");
OP_INOUT_CHECK(ctx->HasInput("AttentionScalar"), "Input", "AttentionScalar",
"AttentionLstm");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalarBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
platform::errors::InvalidArgument(
"Input(AttentionScalarBias)'s rank must be 2."));
PADDLE_ENFORCE_EQ(dims[0], 1,
platform::errors::InvalidArgument(
"AttentionScalarBias shapes must be 1 * 1."));
PADDLE_ENFORCE_EQ(dims[1], 1,
platform::errors::InvalidArgument(
"AttentionScalarBias shapes must be 1 * 1."));
}
framework::DDim out_dims({x_dims[0], D});
......@@ -301,8 +323,11 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
int len = x_lod[0][i + 1] - x_lod[0][i];
max_seq_len = max_seq_len < len ? len : max_seq_len;
}
PADDLE_ENFORCE_EQ(x_lod.size(), 1UL, "Input(X)'s lod size must be 1.");
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
PADDLE_ENFORCE_EQ(x_lod.size(), 1UL, platform::errors::InvalidArgument(
"Input(X)'s lod size must be 1."));
PADDLE_ENFORCE_EQ(
c0->dims()[0], N,
platform::errors::InvalidArgument("C0 dims should be %d x %d.", N, D));
fc_out->Resize({max_seq_len, 1});
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
......
......@@ -16,9 +16,11 @@ from __future__ import print_function
import six
import warnings
import sys
from .initializer import Initializer, Xavier, Constant
from .regularizer import WeightDecayRegularizer
from paddle.fluid.data_feeder import check_type
__all__ = [
'ParamAttr',
......@@ -77,8 +79,17 @@ class ParamAttr(object):
regularizer=None,
trainable=True,
do_model_average=True):
if sys.version_info.major == 2:
check_type(name, "name", (str, type(None), unicode), "ParamAttr")
else:
check_type(name, "name", (str, type(None)), "ParamAttr")
check_type(learning_rate, "learning_rate", (float, int), "ParamAttr")
check_type(trainable, "trainable", (bool), "ParamAttr")
check_type(do_model_average, "do_model_average", (bool), "ParamAttr")
self.name = name
if isinstance(self.name, six.string_types) and self.name == "":
if self.name == "":
raise ValueError("name of ParamAttr can not be empty str")
self.initializer = initializer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部