未验证 提交 0a7bab4e 编写于 作者: F Feiyu Chan 提交者: GitHub

fix error mesage for negative_positive_pair_op and nce_op (#27779)

上级 395cb561
......@@ -104,25 +104,29 @@ class NCEKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
dist_probs->numel(), num_total_classes,
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_probs->numel(), num_total_classes);
dist_probs->numel(), num_total_classes));
PADDLE_ENFORCE_EQ(
dist_alias->numel(), num_total_classes,
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistAlias) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_alias->numel(), num_total_classes);
dist_alias->numel(), num_total_classes));
PADDLE_ENFORCE_EQ(
dist_alias_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAliasProbs) "
platform::errors::InvalidArgument(
"ShapeError: The number of elements in "
"Input(CustomDistAliasProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAliasProbs).numel() = %d, "
"Attr(num_total_classes) = %d.",
dist_alias_probs->numel(), num_total_classes);
dist_alias_probs->numel(), num_total_classes));
const float *probs_data = dist_probs->data<float>();
const int *alias_data = dist_alias->data<int>();
......@@ -140,10 +144,11 @@ class NCEKernel : public framework::OpKernel<T> {
for (int x = 0; x < sample_labels->numel(); x++) {
PADDLE_ENFORCE_GE(sample_labels_data[x], 0,
platform::errors::InvalidArgument(
"ValueError: Every sample label should be "
"non-negative. But received: "
"Input(SampleLabels)[%d] = %d",
x, sample_labels_data[x]);
x, sample_labels_data[x]));
}
auto sample_out = context.Output<Tensor>("SampleLogits");
......@@ -311,25 +316,29 @@ class NCEGradKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
dist_probs->numel(), num_total_classes,
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_probs->numel(), num_total_classes);
dist_probs->numel(), num_total_classes));
PADDLE_ENFORCE_EQ(
dist_alias->numel(), num_total_classes,
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistAlias) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_alias->numel(), num_total_classes);
dist_alias->numel(), num_total_classes));
PADDLE_ENFORCE_EQ(
dist_alias_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAliasProbs) "
platform::errors::InvalidArgument(
"ShapeError: The number of elements in "
"Input(CustomDistAliasProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAliasProbs).numel() = %d, "
"Attr(num_total_classes) = %d.",
dist_alias_probs->numel(), num_total_classes);
dist_alias_probs->numel(), num_total_classes));
const float *probs_data = dist_probs->data<float>();
const int *alias_data = dist_alias->data<int>();
......
......@@ -37,13 +37,15 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
if (ctx->HasInput("AccumulatePositivePair") ||
ctx->HasInput("AccumulateNegativePair") ||
ctx->HasInput("AccumulateNeutralPair")) {
PADDLE_ENFORCE(ctx->HasInput("AccumulatePositivePair") &&
PADDLE_ENFORCE_EQ(
ctx->HasInput("AccumulatePositivePair") &&
ctx->HasInput("AccumulateNegativePair") &&
ctx->HasInput("AccumulateNeutralPair"),
true, platform::errors::InvalidArgument(
"All optional inputs(AccumulatePositivePair, "
"AccumulateNegativePair, AccumulateNeutralPair) of "
"PositiveNegativePairOp are required if one of them is "
"specified.");
"PositiveNegativePairOp are required if one of them "
"is specified."));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("AccumulatePositivePair"), scalar_dim,
platform::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册