未验证 提交 f5f76e61 编写于 作者: Y yiicy 提交者: GitHub

fusion_seqconv_eltadd_relu error message enhancement. (#23554)

上级 b4daea13
......@@ -23,36 +23,53 @@ namespace operators {
void FusionSeqConvEltAddReluOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("Filter"),
"Input(Filter) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("Bias"),
"Input(Bias) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("ColMat"),
"Output(ColMat) of FusionSeqConvEltAddReluOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"fusion_seqconv_eltadd_relu");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter",
"fusion_seqconv_eltadd_relu");
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias",
"fusion_seqconv_eltadd_relu");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"fusion_seqconv_eltadd_relu");
OP_INOUT_CHECK(ctx->HasOutput("ColMat"), "Output", "ColMat",
"fusion_seqconv_eltadd_relu");
auto x_dims = ctx->GetInputDim("X");
auto w_dims = ctx->GetInputDim("Filter");
int context_length = ctx->Attrs().Get<int>("contextLength");
PADDLE_ENFORCE(
ctx->Attrs().Get<int>("contextStride") == 1,
"Currently, FusionSeqConvEltAddReluOp only supports contextStride=1.");
PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE(w_dims[0] == context_length * x_dims[1],
"Filter's height should be context_length * "
"input_hidden_size .");
PADDLE_ENFORCE_GT(context_length + ctx->Attrs().Get<int>("contextStart"), 0,
"contextStart size should be smaller than contextLength.");
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<int>("contextStride"), 1,
platform::errors::InvalidArgument(
"Currently, FusionSeqConvEltAddReluOp only supports "
"contextStride=1, but received value is: %d.",
ctx->Attrs().Get<int>("contextStride")));
PADDLE_ENFORCE_EQ(
x_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X) should be 2-D tensor, but reveiced value is: %d.",
x_dims.size()));
PADDLE_ENFORCE_EQ(
w_dims.size(), 2,
platform::errors::InvalidArgument(
"Filter should be 2-D tensor, but reveiced value is: %d.",
w_dims.size()));
PADDLE_ENFORCE_EQ(w_dims[0], context_length * x_dims[1],
platform::errors::InvalidArgument(
"Filter's height should be equal to context_length * "
"input_hidden_size, but received Filter height is: %d,"
"context_length is: %d, input_hidden_size is: %d.",
w_dims[0], context_length, x_dims[1]));
PADDLE_ENFORCE_GT(
context_length + ctx->Attrs().Get<int>("contextStart"), 0,
platform::errors::InvalidArgument(
"contextStart size should be smaller than contextLength, "
"but received context_length is: %d, contextStart is: "
"%d.",
context_length, ctx->Attrs().Get<int>("contextStart")));
ctx->SetOutputDim("Out", {x_dims[0], w_dims[1]});
ctx->SetOutputDim("ColMat", {x_dims[0], w_dims[0]});
......@@ -130,10 +147,17 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
auto x_lod = x->lod();
auto x_dims = x->dims();
auto w_dims = w->dims();
PADDLE_ENFORCE_EQ(b->numel(), w_dims[1],
"bias size should be equal to output feature size.");
PADDLE_ENFORCE_EQ(x_lod.size(), 1UL,
"Only support one level sequence now.");
PADDLE_ENFORCE_EQ(
b->numel(), w_dims[1],
platform::errors::InvalidArgument(
"bias size should be equal to weights feature size, but received "
"bias size is: %d, weights feature size is: %d.",
b->numel(), w_dims[1]));
PADDLE_ENFORCE_EQ(
x_lod.size(), 1UL,
platform::errors::InvalidArgument(
"Only support one level sequence now, but received value is: %d.",
x_lod.size()));
const T* x_data = x->data<T>();
const T* w_data = w->data<T>();
......@@ -183,7 +207,12 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
copy_size -= src_mat_w_sz;
}
} else {
PADDLE_ENFORCE_GE(context_length, up_pad + down_pad + 1);
PADDLE_ENFORCE_GE(context_length, up_pad + down_pad + 1,
platform::errors::InvalidArgument(
"context length must be bigger or equal than "
"up_pad + down_pad + 1, but received context "
"length is: %d, up_pad is: %d, down_pad is: %d.",
context_length, up_pad, down_pad));
std::memset(dst_data, 0, seq_len * col_mat_w_sz);
dst_data = dst_data + up_pad * src_mat_w;
int zero_sz = up_pad * src_mat_w_sz;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册