未验证 提交 f91c37e6 编写于 作者: A Aurelius84 提交者: GitHub

Refine error message of MatchMatrix and PyramidHash (#27484)

上级 8f7bb52b
......@@ -28,34 +28,54 @@ using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
void MatchMatrixTensorOP::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"X(Input) of MatchMatrix should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
"Y(Input) of MatchMatrix should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
"W(Input) of MatchMatrix should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Out(Output) of MatchMatrix should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Tmp"), true,
"Tmp(Output) of MatchMatrix should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "match_matrix_tensor");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "match_matrix_tensor");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "match_matrix_tensor");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "match_matrix_tensor");
OP_INOUT_CHECK(ctx->HasOutput("Tmp"), "Output", "Tmp", "match_matrix_tensor");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
"The rank of Input(X) can't be less than 2.");
platform::errors::InvalidArgument(
"The dimensions of Input(X) should be equal to 2, "
"but received %d.",
x_dims.size()));
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(y_dims.size(), 2,
"The rank of Input(Y) can't be less than 2.");
platform::errors::InvalidArgument(
"The dimensions of Input(Y) should be equal to 2, "
"but received %d.",
y_dims.size()));
auto w_dims = ctx->GetInputDim("W");
PADDLE_ENFORCE_EQ(w_dims.size(), 3UL, "W should be 3-D tensor");
PADDLE_ENFORCE_EQ(w_dims.size(), 3,
platform::errors::InvalidArgument(
"The dimensions of Input(W) should be equal to 3, "
"but received %d.",
w_dims.size()));
int dim_t = ctx->Attrs().Get<int>("dim_t");
PADDLE_ENFORCE_EQ(w_dims[0], x_dims[1],
"W 's shape must satisfy: W[0] = X[1]");
PADDLE_ENFORCE_EQ(w_dims[1], dim_t, "W 's shape must satisfy: W[1] = dim_t");
PADDLE_ENFORCE_EQ(w_dims[2], y_dims[1],
"W 's shape must satisfy: W[2] = Y[1]");
PADDLE_ENFORCE_EQ(
w_dims[0], x_dims[1],
platform::errors::InvalidArgument(
"The first dimension of Input(W) should be equal to the second "
"dimension of Input(X). But received the first dimension of Input(W) "
"is %d, the second dimension of Input(X) is %d.",
w_dims[0], x_dims[1]));
PADDLE_ENFORCE_EQ(
w_dims[1], dim_t,
platform::errors::InvalidArgument(
"The second dimension of Input(W) should be equal to 'dim_t', but "
"received the second dimension of Input(W) is %d, 'dim_t' is %d.",
w_dims[1], dim_t));
PADDLE_ENFORCE_EQ(
w_dims[2], y_dims[1],
platform::errors::InvalidArgument(
"The last dimension of Input(W) should be equal to "
"the second dimension of Input(Y). But received the last dimension "
"of Input(W) is %d, the second dimension of Input(Y) is %d.",
w_dims[2], y_dims[1]));
int64_t out_dim_0 = -1;
int64_t tmp_dim_0 = -1;
......@@ -63,27 +83,52 @@ void MatchMatrixTensorOP::InferShape(framework::InferShapeContext* ctx) const {
framework::Variable* x_var =
BOOST_GET(framework::Variable*, ctx->GetInputVarPtrs("X")[0]);
const auto& x_lod = x_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE_EQ(x_lod.empty(), false, "The Input(X) must hold lod info.");
PADDLE_ENFORCE_EQ(x_lod.empty(), false,
platform::errors::InvalidArgument(
"The Input(X) should hold LoD information, but "
"received Input(X).lod() is empty."));
const auto& x_lod_0 = x_lod[0];
PADDLE_ENFORCE_GE(x_lod_0.size(), 2,
"The Input(X)'s lod info is corrupted.");
PADDLE_ENFORCE_EQ(
x_dims[0], static_cast<int64_t>(x_lod_0.back()),
"The Input(X)'s lod info mismatches the actual tensor shape.");
platform::errors::InvalidArgument(
"The dimensions of Input(X)'s LoD data should be "
"equal to 2, but received %d.",
x_lod_0.size()));
PADDLE_ENFORCE_EQ(x_dims[0], static_cast<int64_t>(x_lod_0.back()),
platform::errors::InvalidArgument(
"The last element of Input(X)'s LoD data should be "
"equal to the first dimension of Input(X). "
"But received the last element of Input(X)'s LoD "
"data is %d, the first dimension of Input(X) is %d.",
x_lod_0.back(), x_dims[0]));
framework::Variable* y_var =
BOOST_GET(framework::Variable*, ctx->GetInputVarPtrs("Y")[0]);
const auto& y_lod = y_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE_EQ(y_lod.empty(), false, "The Input(Y) must hold lod info.");
PADDLE_ENFORCE_EQ(y_lod.empty(), false,
platform::errors::InvalidArgument(
"The Input(Y) should hold LoD information, but "
"received Input(Y).lod() is empty."));
const auto& y_lod_0 = y_lod[0];
PADDLE_ENFORCE_GE(y_lod_0.size(), 2,
"The Input(Y)'s lod info is corrupted.");
PADDLE_ENFORCE_EQ(
y_dims[0], static_cast<int64_t>(y_lod_0.back()),
"The Input(Y)'s lod info mismatches the actual tensor shape.");
platform::errors::InvalidArgument(
"The dimensions of Input(Y)'s LoD data should be "
"equal to 2, but received %d.",
y_lod_0.size()));
PADDLE_ENFORCE_EQ(y_dims[0], static_cast<int64_t>(y_lod_0.back()),
platform::errors::InvalidArgument(
"The last element of Input(Y)'s LoD data should be "
"equal to the first dimension of Input(Y). "
"But received the last element of Input(Y)'s LoD "
"data is %d, the first dimension of Input(Y) is %d.",
y_lod_0.back(), y_dims[0]));
PADDLE_ENFORCE_EQ(x_lod_0.size(), y_lod_0.size(),
"The Length of X and Y must be equal.");
platform::errors::InvalidArgument(
"The dimensions of Input(X)'s and Input(Y)'s LoD "
"data should be equal. "
"But received the dimensions of Input(X)'s LoD is "
"%d, the dimensions of Input(Y)'s LoD is %d.",
x_lod_0.size(), y_lod_0.size()));
out_dim_0 = 0;
for (size_t i = 1; i < x_lod_0.size(); i++) {
......@@ -98,10 +143,18 @@ void MatchMatrixTensorOP::InferShape(framework::InferShapeContext* ctx) const {
// compile time
framework::VarDesc* x_desc =
BOOST_GET(framework::VarDesc*, ctx->GetInputVarPtrs("X")[0]);
PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1);
PADDLE_ENFORCE_GE(
x_desc->GetLoDLevel(), 1,
platform::errors::InvalidArgument("The LoD level of Input(X) should be "
"greater than 1, but reviced %d.",
x_desc->GetLoDLevel()));
framework::VarDesc* y_desc =
BOOST_GET(framework::VarDesc*, ctx->GetInputVarPtrs("Y")[0]);
PADDLE_ENFORCE_GE(y_desc->GetLoDLevel(), 1);
PADDLE_ENFORCE_GE(
y_desc->GetLoDLevel(), 1,
platform::errors::InvalidArgument("The LoD level of Input(Y) should be "
"greater than 1, but reviced %d.",
y_desc->GetLoDLevel()));
ctx->ShareLoD("X", "Out");
}
......@@ -115,14 +168,11 @@ void MatchMatrixTensorOP::InferShape(framework::InferShapeContext* ctx) const {
void MatchMatrixTensorOpGrad::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of SequencePadGradOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
"Input(Y) of SequencePadGradOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
"Input(W) of SequencePadGradOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) of SequencePadGradOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "match_matrix_tensor_grad");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "match_matrix_tensor_grad");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "match_matrix_tensor_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "match_matrix_tensor_grad");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
......
......@@ -285,13 +285,21 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
if (use_filter) {
if (white_list_len != 0) {
_filter = (math::bloomfilter*)_blobs_1->data<float>();
PADDLE_ENFORCE_EQ(math::bloomfilter_check(_filter), 1,
"white filter not load");
PADDLE_ENFORCE_EQ(
math::bloomfilter_check(_filter), 1,
platform::errors::PreconditionNotMet(
"The white filter is not loaded successfully, please make sure "
"'white_list_len': %d is valid for Input(WhiteList).",
white_list_len));
}
if (black_list_len != 0) {
_black_filter = (math::bloomfilter*)_blobs_2->data<float>();
PADDLE_ENFORCE_EQ(math::bloomfilter_check(_black_filter), 1,
"black filter not load");
PADDLE_ENFORCE_EQ(
math::bloomfilter_check(_black_filter), 1,
platform::errors::PreconditionNotMet(
"The black filter is not loaded successfully, please make sure "
"'black_list_len': %d is valid for Input(BlackList).",
black_list_len));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册