未验证 提交 fd9b7bdb 编写于 作者: Z zhangchunle 提交者: GitHub

Op (FusedEmbeddingSeqPool) error message enhancement. (#23454)

上级 16315d3d
...@@ -24,30 +24,47 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { ...@@ -24,30 +24,47 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"), OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "FusedEmbeddingSeqPool");
"Input W of FusedEmbeddingSeqPoolOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids",
PADDLE_ENFORCE(ctx->HasInput("Ids"), "FusedEmbeddingSeqPool");
"Input Ids of FusedEmbeddingSeqPoolOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
PADDLE_ENFORCE(ctx->HasOutput("Out"), "FusedEmbeddingSeqPool");
"Output of FusedEmbeddingSeqPoolOp should not be null.");
auto table_dims = ctx->GetInputDim("W"); auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputDim("Ids");
const std::string& combiner = ctx->Attrs().Get<std::string>("combiner"); const std::string& combiner = ctx->Attrs().Get<std::string>("combiner");
PADDLE_ENFORCE_EQ(table_dims.size(), 2); PADDLE_ENFORCE_EQ(table_dims.size(), 2,
PADDLE_ENFORCE_GE(ids_dims.size(), 1, platform::errors::InvalidArgument(
"The dim size of the 'Ids' tensor must greater than 1."); "The dim size of the input tensor 'W' should be 2. "
PADDLE_ENFORCE_EQ(ids_dims[ids_dims.size() - 1], 1, "But received W's size = %d.",
"The last dimension of the 'Ids' tensor must be 1."); table_dims.size()));
PADDLE_ENFORCE_GE(
ids_dims.size(), 1,
platform::errors::InvalidArgument(
"The dim size of the input tensor 'Ids' should be greater "
"than or equal to 1. But received Ids's size = %d.",
ids_dims.size()));
PADDLE_ENFORCE_EQ(
ids_dims[ids_dims.size() - 1], 1,
platform::errors::InvalidArgument(
"The last dimension of the input tensor 'Ids' should be 1. "
"But received Ids's size in the last dimension = %d.",
ids_dims[ids_dims.size() - 1]));
// we only support sum now // we only support sum now
PADDLE_ENFORCE_EQ(combiner, "sum"); PADDLE_ENFORCE_EQ(combiner, "sum",
platform::errors::Unimplemented(
"The pooling type of sequence_pool only support sum "
"now. So the 'combiner' must be 'sum'."));
int64_t last_dim = FusedEmbeddingSeqPoolLastDim(table_dims, ids_dims); int64_t last_dim = FusedEmbeddingSeqPoolLastDim(table_dims, ids_dims);
// in compile time, the lod level of ids must be 1 // in compile time, the lod level of ids must be 1
framework::VarDesc* ids_desc = framework::VarDesc* ids_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]); boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1); PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1,
platform::errors::InvalidArgument(
"In compile time, the LoD Level of Ids should be 1. "
"But received the LoD Level of Ids = %d.",
ids_desc->GetLoDLevel()));
// in compile time, the shape from Ids -> output // in compile time, the shape from Ids -> output
// should be [-1, 1] -> [-1, embedding_size] // should be [-1, 1] -> [-1, embedding_size]
......
...@@ -90,8 +90,17 @@ struct EmbeddingVSumFunctor { ...@@ -90,8 +90,17 @@ struct EmbeddingVSumFunctor {
int64_t idx_width = ids_t->numel() / ids_lod.back(); int64_t idx_width = ids_t->numel() / ids_lod.back();
auto *output = output_t->mutable_data<T>(context.GetPlace()); auto *output = output_t->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_LE(table_width * idx_width, out_width); PADDLE_ENFORCE_LE(table_width * idx_width, out_width,
PADDLE_ENFORCE_GT(ids_lod.size(), 1UL, "The LoD[0] could NOT be empty"); platform::errors::InvalidArgument(
"table_width * idx_width should be less than or "
"equal to out_width. But received "
"table_width * idx_width = %s, out_width = %d.",
table_width * idx_width, out_width));
PADDLE_ENFORCE_GT(ids_lod.size(), 1UL,
platform::errors::InvalidArgument(
"The tensor ids's LoD[0] should be greater than 1. "
"But received the ids's LoD[0] = %d.",
ids_lod.size()));
jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width, jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width,
out_width, jit::SeqPoolType::kSum); out_width, jit::SeqPoolType::kSum);
...@@ -130,7 +139,10 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> { ...@@ -130,7 +139,10 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
const auto &ids_lod = ids_t->lod(); const auto &ids_lod = ids_t->lod();
// in run time, the LoD of ids must be 1 // in run time, the LoD of ids must be 1
PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL, PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL,
"The LoD level of Input(Ids) must be 1"); platform::errors::InvalidArgument(
"The LoD level of Input(Ids) should be 1. But "
"received Ids's LoD level = %d.",
ids_lod.size()));
int64_t batch_size = ids_lod[0].size() - 1; int64_t batch_size = ids_lod[0].size() - 1;
// in run time, the shape from Ids -> output // in run time, the shape from Ids -> output
// should be [seq_length, 1] -> [batch_size, last_dim] // should be [seq_length, 1] -> [batch_size, last_dim]
...@@ -244,7 +256,10 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -244,7 +256,10 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
const auto &ids_lod = ids->lod(); const auto &ids_lod = ids->lod();
PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL, PADDLE_ENFORCE_EQ(ids_lod.size(), 1UL,
"The LoD level of Input(Ids) must be 1"); platform::errors::InvalidArgument(
"The LoD level of Input(Ids) should be 1. But "
"received Ids's LoD level = %d.",
ids_lod.size()));
const std::vector<uint64_t> offset = ids_lod[0]; const std::vector<uint64_t> offset = ids_lod[0];
auto len = ids->numel(); auto len = ids->numel();
int idx_width = len / offset.back(); int idx_width = len / offset.back();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册