Unverified commit c5707159, authored by yiicy, committed by GitHub


[cherry-pick] Variable, fusion_seqexpand_concat_fc, fusion_seqconv_eltadd_relu error message enhancement, test=develop (#24066)
Parent 26a1def9
@@ -30,11 +30,13 @@ class Variable {
     static_assert(
         IsRegisteredVarType<T>(),
         "Not registered type. Please register T inside var_type_traits.h");
-    PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized.");
-    PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
-                   "The Variable type must be %s, but the type it holds is %s.",
-                   ToTypeName(VarTypeTrait<T>::kId),
-                   ToTypeName(holder_->Type()));
+    PADDLE_ENFORCE_NOT_NULL(
+        holder_, platform::errors::NotFound("Variable is not initialized."));
+    PADDLE_ENFORCE_EQ(
+        holder_->Type(), VarTypeTrait<T>::kId,
+        platform::errors::InvalidArgument(
+            "The Variable type must be %s, but the type it holds is %s.",
+            ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type())));
     return *static_cast<const T*>(holder_->Ptr());
   }
@@ -45,10 +47,11 @@ class Variable {
     if (!holder_) {
       holder_.reset(new PlaceholderImpl<T>());
     } else {
-      PADDLE_ENFORCE(
-          holder_->Type() == VarTypeTrait<T>::kId,
-          "The Variable type must be %s, but the type it holds is %s.",
-          ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type()));
+      PADDLE_ENFORCE_EQ(
+          holder_->Type(), VarTypeTrait<T>::kId,
+          platform::errors::InvalidArgument(
+              "The Variable type must be %s, but the type it holds is %s.",
+              ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type())));
     }
     return static_cast<T*>(holder_->Ptr());
   }
@@ -61,7 +64,8 @@ class Variable {
   void Clear() { holder_.reset(); }

   int Type() const {
-    PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized.");
+    PADDLE_ENFORCE_NOT_NULL(
+        holder_, platform::errors::NotFound("Variable is not initialized."));
     return holder_->Type();
   }
...
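A note on the pattern this commit applies throughout: the old PADDLE_ENFORCE takes a boolean condition and a bare message, while PADDLE_ENFORCE_NOT_NULL / PADDLE_ENFORCE_EQ take the operands themselves plus a typed error (platform::errors::NotFound, InvalidArgument, ...), so a failure can report both the expected and the actual value. A minimal standalone sketch of the idea; the macros below are simplified stand-ins invented for illustration, not Paddle's real implementations (which throw typed exceptions rather than abort):

#include <cstdio>
#include <cstdlib>

// Simplified stand-in for PADDLE_ENFORCE_NOT_NULL: checks the pointer itself.
#define ENFORCE_NOT_NULL(ptr, ...)       \
  do {                                   \
    if ((ptr) == nullptr) {              \
      std::fprintf(stderr, __VA_ARGS__); \
      std::abort();                      \
    }                                    \
  } while (0)

// Simplified stand-in for PADDLE_ENFORCE_EQ: takes both operands, so the
// message can include expected and actual values.
#define ENFORCE_EQ(a, b, ...)            \
  do {                                   \
    if ((a) != (b)) {                    \
      std::fprintf(stderr, __VA_ARGS__); \
      std::abort();                      \
    }                                    \
  } while (0)

int main() {
  int held_type = 3;            // stands in for holder_->Type()
  const int expected_type = 3;  // stands in for VarTypeTrait<T>::kId
  int* holder = &held_type;
  ENFORCE_NOT_NULL(holder, "Variable is not initialized.\n");
  ENFORCE_EQ(held_type, expected_type,
             "The Variable type must be %d, but the type it holds is %d.\n",
             expected_type, held_type);
  std::puts("all checks passed");
  return 0;
}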
@@ -23,36 +23,53 @@ namespace operators {

 void FusionSeqConvEltAddReluOp::InferShape(
     framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"),
-                 "Input(X) of FusionSeqConvEltAddReluOp should not be null.");
-  PADDLE_ENFORCE(
-      ctx->HasInput("Filter"),
-      "Input(Filter) of FusionSeqConvEltAddReluOp should not be null.");
-  PADDLE_ENFORCE(
-      ctx->HasInput("Bias"),
-      "Input(Bias) of FusionSeqConvEltAddReluOp should not be null.");
-  PADDLE_ENFORCE(
-      ctx->HasOutput("Out"),
-      "Output(Out) of FusionSeqConvEltAddReluOp should not be null.");
-  PADDLE_ENFORCE(
-      ctx->HasOutput("ColMat"),
-      "Output(ColMat) of FusionSeqConvEltAddReluOp should not be null.");
+  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
+                 "fusion_seqconv_eltadd_relu");
+  OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter",
+                 "fusion_seqconv_eltadd_relu");
+  OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias",
+                 "fusion_seqconv_eltadd_relu");
+
+  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
+                 "fusion_seqconv_eltadd_relu");
+  OP_INOUT_CHECK(ctx->HasOutput("ColMat"), "Output", "ColMat",
+                 "fusion_seqconv_eltadd_relu");

   auto x_dims = ctx->GetInputDim("X");
   auto w_dims = ctx->GetInputDim("Filter");
   int context_length = ctx->Attrs().Get<int>("contextLength");
-  PADDLE_ENFORCE(
-      ctx->Attrs().Get<int>("contextStride") == 1,
-      "Currently, FusionSeqConvEltAddReluOp only supports contextStride=1.");
-  PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
-                 "Input(X, Filter) should be 2-D tensor.");
-  PADDLE_ENFORCE(w_dims[0] == context_length * x_dims[1],
-                 "Filter's height should be context_length * "
-                 "input_hidden_size .");
-  PADDLE_ENFORCE_GT(context_length + ctx->Attrs().Get<int>("contextStart"), 0,
-                    "contextStart size should be smaller than contextLength.");
+  PADDLE_ENFORCE_EQ(ctx->Attrs().Get<int>("contextStride"), 1,
+                    platform::errors::InvalidArgument(
+                        "Currently, FusionSeqConvEltAddReluOp only supports "
+                        "contextStride=1, but received value is: %d.",
+                        ctx->Attrs().Get<int>("contextStride")));
+  PADDLE_ENFORCE_EQ(
+      x_dims.size(), 2,
+      platform::errors::InvalidArgument(
+          "Input(X) should be a 2-D tensor, but received rank is: %d.",
+          x_dims.size()));
+  PADDLE_ENFORCE_EQ(
+      w_dims.size(), 2,
+      platform::errors::InvalidArgument(
+          "Filter should be a 2-D tensor, but received rank is: %d.",
+          w_dims.size()));
+  PADDLE_ENFORCE_EQ(w_dims[0], context_length * x_dims[1],
+                    platform::errors::InvalidArgument(
+                        "Filter's height should be equal to context_length * "
+                        "input_hidden_size, but received Filter height is: "
+                        "%d, context_length is: %d, input_hidden_size is: %d.",
+                        w_dims[0], context_length, x_dims[1]));
+  PADDLE_ENFORCE_GT(
+      context_length + ctx->Attrs().Get<int>("contextStart"), 0,
+      platform::errors::InvalidArgument(
+          "contextStart size should be smaller than contextLength, but "
+          "received contextLength is: %d, contextStart is: %d.",
+          context_length, ctx->Attrs().Get<int>("contextStart")));

   ctx->SetOutputDim("Out", {x_dims[0], w_dims[1]});
   ctx->SetOutputDim("ColMat", {x_dims[0], w_dims[0]});
@@ -130,10 +147,17 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
     auto x_lod = x->lod();
     auto x_dims = x->dims();
     auto w_dims = w->dims();
-    PADDLE_ENFORCE_EQ(b->numel(), w_dims[1],
-                      "bias size should be equal to output feature size.");
-    PADDLE_ENFORCE_EQ(x_lod.size(), 1UL,
-                      "Only support one level sequence now.");
+    PADDLE_ENFORCE_EQ(
+        b->numel(), w_dims[1],
+        platform::errors::InvalidArgument(
+            "bias size should be equal to weights feature size, but received "
+            "bias size is: %d, weights feature size is: %d.",
+            b->numel(), w_dims[1]));
+    PADDLE_ENFORCE_EQ(
+        x_lod.size(), 1UL,
+        platform::errors::InvalidArgument(
+            "Only support one level sequence now, but received value is: %d.",
+            x_lod.size()));

     const T* x_data = x->data<T>();
     const T* w_data = w->data<T>();
@@ -183,7 +207,12 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
           copy_size -= src_mat_w_sz;
         }
       } else {
-        PADDLE_ENFORCE_GE(context_length, up_pad + down_pad + 1);
+        PADDLE_ENFORCE_GE(context_length, up_pad + down_pad + 1,
+                          platform::errors::InvalidArgument(
+                              "context length must be greater than or equal "
+                              "to up_pad + down_pad + 1, but received context "
+                              "length is: %d, up_pad is: %d, down_pad is: %d.",
+                              context_length, up_pad, down_pad));
         std::memset(dst_data, 0, seq_len * col_mat_w_sz);
         dst_data = dst_data + up_pad * src_mat_w;
         int zero_sz = up_pad * src_mat_w_sz;
...
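The PADDLE_ENFORCE_GE above guards the zero-padding path of the im2col-style unroll. Assuming up_pad = max(0, -contextStart) and down_pad = max(0, contextStart + contextLength - 1), the usual sequence-conv convention (an assumption here, since the pad computation sits outside this hunk), the guard says the context window must be wide enough to cover both pads. A rough sketch of that arithmetic:

#include <algorithm>
#include <cassert>
#include <cstdio>

int main() {
  // Hypothetical attribute values.
  const int context_length = 3;
  const int context_start = -1;

  // Assumed pad formulas (mirroring the plain sequence-conv convention).
  const int up_pad = std::max(0, -context_start);                        // 1
  const int down_pad = std::max(0, context_start + context_length - 1);  // 1

  // The kernel's guard: the window covers both padded regions.
  assert(context_length >= up_pad + down_pad + 1);
  std::printf("up_pad=%d down_pad=%d window=%d\n", up_pad, down_pad,
              context_length);
  return 0;
}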
@@ -24,38 +24,59 @@ namespace operators {

 void FusionSeqExpandConcatFCOp::InferShape(
     framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE_GT(
-      ctx->Inputs("X").size(), 1UL,
-      "Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1.");
-  PADDLE_ENFORCE(
-      ctx->HasInput("FCWeight"),
-      "Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null.");
-  PADDLE_ENFORCE(
-      ctx->HasOutput("Out"),
-      "Output(Out) of FusionSeqExpandConcatFCOp should not be null.");
-  PADDLE_ENFORCE(
-      ctx->HasOutput("FCOut"),
-      "Output(FCOut) of FusionSeqExpandConcatFCOp should not be null.");
+  PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL,
+                    platform::errors::InvalidArgument(
+                        "Inputs(X) of FusionSeqExpandConcatFCOp should be "
+                        "larger than 1, but received value is: %d.",
+                        ctx->Inputs("X").size()));
+  OP_INOUT_CHECK(ctx->HasInput("FCWeight"), "Input", "FCWeight",
+                 "fusion_seqexpand_concat_fc");
+  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
+                 "fusion_seqexpand_concat_fc");
+  OP_INOUT_CHECK(ctx->HasOutput("FCOut"), "Output", "FCOut",
+                 "fusion_seqexpand_concat_fc");

   auto ins_dims = ctx->GetInputsDim("X");
   auto w_dims = ctx->GetInputDim("FCWeight");  // (M0+M1+M2+..) x D
-  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(FCWeight)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(
+      w_dims.size(), 2,
+      platform::errors::InvalidArgument(
+          "Input(FCWeight)'s rank must be 2, but received value is: %d.",
+          w_dims.size()));
   const int D = w_dims[1];
   int sum = ins_dims[0][1];
   for (size_t i = 1; i < ins_dims.size(); ++i) {
     sum += ins_dims[i][1];
   }
-  PADDLE_ENFORCE_EQ(sum, w_dims[0],
-                    "FC height should be sum of all inputs width.");
+  PADDLE_ENFORCE_EQ(sum, w_dims[0],
+                    platform::errors::InvalidArgument(
+                        "FC height should be sum of all inputs "
+                        "width, but received FC height is: %d, "
+                        "sum of all inputs width is: %d.",
+                        w_dims[0], sum));

   if (ctx->HasInput("FCBias")) {
     auto b_dims = ctx->GetInputDim("FCBias");
-    PADDLE_ENFORCE(b_dims.size() == 1 || b_dims.size() == 2,
-                   "b_dims should be 1 or 2, get %d", b_dims.size());
+    PADDLE_ENFORCE_EQ(
+        b_dims.size() == 1 || b_dims.size() == 2, true,
+        platform::errors::InvalidArgument(
+            "FCBias dim should be 1 or 2, but received value is: %d.",
+            b_dims.size()));
     if (b_dims.size() == 1) {
-      PADDLE_ENFORCE_EQ(b_dims[0], D, "FCBias shapes must be %d.", D);
+      PADDLE_ENFORCE_EQ(b_dims[0], D,
+                        platform::errors::InvalidArgument(
+                            "FCBias shapes must be %d when FCBias dim = 1, "
+                            "but received value is: %d.",
+                            D, b_dims[0]));
     } else {
-      PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1x%d.", D);
-      PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1x%d.", D);
+      PADDLE_ENFORCE_EQ(b_dims[0], 1,
+                        platform::errors::InvalidArgument(
+                            "FCBias shapes must be 1x%d when FCBias dim = 2, "
+                            "but received dim[0] is: %d.",
+                            D, b_dims[0]));
+      PADDLE_ENFORCE_EQ(b_dims[1], D,
+                        platform::errors::InvalidArgument(
+                            "FCBias shapes must be 1x%d when FCBias dim = 2, "
+                            "but received dim[1] is: %d.",
+                            D, b_dims[1]));
     }
   }
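The comment `// (M0+M1+M2+..) x D` is the whole contract being checked here: inputs of widths M0, M1, ... are concatenated along the feature axis, so their summed width must equal FCWeight's height, and an optional FCBias must be shaped (D) or (1 x D). A toy check with invented dimensions:

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // Invented input widths M0, M1, M2 and an FC weight of shape (sum x D).
  const std::vector<int> input_widths = {4, 6, 2};
  const int D = 16;
  int sum = 0;
  for (int w : input_widths) sum += w;

  const int fc_weight_height = 12;  // w_dims[0]; must equal the summed width
  assert(sum == fc_weight_height);

  // FCBias, when present, is rank 1 with D elements or rank 2 with shape 1xD.
  const std::vector<int> b_dims = {1, 16};
  assert(b_dims.size() == 1 ? b_dims[0] == D
                            : (b_dims[0] == 1 && b_dims[1] == D));
  std::printf("FCWeight: (%d x %d), bias shape ok\n", fc_weight_height, D);
  return 0;
}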
@@ -133,18 +154,42 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {

     // some check and fcout should be reshape here
     // since infershape can not get lod info
-    PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL, "Only support input lod size is 1.");
-    PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL, "Only support input lod size is 1.");
-    PADDLE_ENFORCE_EQ(static_cast<int>(in1_lod[0].size() - 1), N,
-                      "Batch size of all inputs should be equal.");
-    PADDLE_ENFORCE_EQ(static_cast<int>(in1_lod[0][N]), N,
-                      "Seq_length of other inputs should be 1.");
-    PADDLE_ENFORCE_EQ(in1_dims[0], N, "input height should be batch size.");
+    PADDLE_ENFORCE_EQ(
+        ref_lod.size(), 1UL,
+        platform::errors::InvalidArgument(
+            "Only support input lod size is 1, but received value is: %d.",
+            ref_lod.size()));
+    PADDLE_ENFORCE_EQ(
+        in1_lod.size(), 1UL,
+        platform::errors::InvalidArgument(
+            "Only support input lod size is 1, but received value is: %d.",
+            in1_lod.size()));
+    PADDLE_ENFORCE_EQ(static_cast<int>(in1_lod[0].size() - 1), N,
+                      platform::errors::InvalidArgument(
+                          "Batch size of all inputs should be equal to %d, "
+                          "but received value is: %d.",
+                          N, static_cast<int>(in1_lod[0].size() - 1)));
+    PADDLE_ENFORCE_EQ(
+        static_cast<int>(in1_lod[0][N]), N,
+        platform::errors::InvalidArgument("Seq_length of other inputs should "
+                                          "be %d, but received value is: %d.",
+                                          N, static_cast<int>(in1_lod[0][N])));
+    PADDLE_ENFORCE_EQ(
+        in1_dims[0], N,
+        platform::errors::InvalidArgument(
+            "Input height should be batch size: %d, but received value is: "
+            "%d.",
+            N, in1_dims[0]));
     for (size_t i = 2; i < ins.size(); ++i) {
       PADDLE_ENFORCE_EQ(ins[i]->dims()[0], N,
-                        "All other inputs height should be equal");
+                        platform::errors::InvalidArgument(
+                            "All other inputs height should be equal to %d, "
+                            "but received value is: %d.",
+                            N, ins[i]->dims()[0]));
       PADDLE_ENFORCE_EQ(ins[i]->lod(), in1_lod,
-                        "All other inputs should have same lod");
+                        platform::errors::InvalidArgument(
+                            "All other inputs should have the same lod: %d, "
+                            "but received value is: %d.",
+                            in1_lod, ins[i]->lod()));
     }
     fc_out->Resize({N, D});
...
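The LoD checks above encode a simple contract: the reference input holds N variable-length sequences, while every other input contributes exactly one row per sequence, so its level-0 lod must be {0, 1, ..., N} and its height must be N. A toy illustration with invented values:

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // Invented LoD: the reference input has N = 3 sequences, of lengths 2, 3, 4.
  const std::vector<size_t> ref_lod0 = {0, 2, 5, 9};
  const int N = static_cast<int>(ref_lod0.size() - 1);

  // Every other input must carry exactly one step per sequence:
  // level-0 lod {0, 1, ..., N} and height N.
  const std::vector<size_t> in1_lod0 = {0, 1, 2, 3};
  assert(static_cast<int>(in1_lod0.size() - 1) == N);  // batch sizes match
  assert(static_cast<int>(in1_lod0[N]) == N);          // one step per sequence
  std::printf("N = %d\n", N);
  return 0;
}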