未验证 提交 0a7bab4e 编写于 作者: F Feiyu Chan 提交者: GitHub

fix error mesage for negative_positive_pair_op and nce_op (#27779)

上级 395cb561
...@@ -104,25 +104,29 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -104,25 +104,29 @@ class NCEKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_probs->numel(), num_total_classes, dist_probs->numel(), num_total_classes,
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistProbs) " "ShapeError: The number of elements in Input(CustomDistProbs) "
"should be equal to the number of total classes. But Received: " "should be equal to the number of total classes. But Received: "
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) " "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
"= %d.", "= %d.",
dist_probs->numel(), num_total_classes); dist_probs->numel(), num_total_classes));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_alias->numel(), num_total_classes, dist_alias->numel(), num_total_classes,
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistAlias) " "ShapeError: The number of elements in Input(CustomDistAlias) "
"should be equal to the number of total classes. But Received: " "should be equal to the number of total classes. But Received: "
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) " "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
"= %d.", "= %d.",
dist_alias->numel(), num_total_classes); dist_alias->numel(), num_total_classes));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_alias_probs->numel(), num_total_classes, dist_alias_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAliasProbs) " platform::errors::InvalidArgument(
"ShapeError: The number of elements in "
"Input(CustomDistAliasProbs) "
"should be equal to the number of total classes. But Received: " "should be equal to the number of total classes. But Received: "
"Input(CustomDistAliasProbs).numel() = %d, " "Input(CustomDistAliasProbs).numel() = %d, "
"Attr(num_total_classes) = %d.", "Attr(num_total_classes) = %d.",
dist_alias_probs->numel(), num_total_classes); dist_alias_probs->numel(), num_total_classes));
const float *probs_data = dist_probs->data<float>(); const float *probs_data = dist_probs->data<float>();
const int *alias_data = dist_alias->data<int>(); const int *alias_data = dist_alias->data<int>();
...@@ -140,10 +144,11 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -140,10 +144,11 @@ class NCEKernel : public framework::OpKernel<T> {
for (int x = 0; x < sample_labels->numel(); x++) { for (int x = 0; x < sample_labels->numel(); x++) {
PADDLE_ENFORCE_GE(sample_labels_data[x], 0, PADDLE_ENFORCE_GE(sample_labels_data[x], 0,
platform::errors::InvalidArgument(
"ValueError: Every sample label should be " "ValueError: Every sample label should be "
"non-negative. But received: " "non-negative. But received: "
"Input(SampleLabels)[%d] = %d", "Input(SampleLabels)[%d] = %d",
x, sample_labels_data[x]); x, sample_labels_data[x]));
} }
auto sample_out = context.Output<Tensor>("SampleLogits"); auto sample_out = context.Output<Tensor>("SampleLogits");
...@@ -311,25 +316,29 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -311,25 +316,29 @@ class NCEGradKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_probs->numel(), num_total_classes, dist_probs->numel(), num_total_classes,
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistProbs) " "ShapeError: The number of elements in Input(CustomDistProbs) "
"should be equal to the number of total classes. But Received: " "should be equal to the number of total classes. But Received: "
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) " "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
"= %d.", "= %d.",
dist_probs->numel(), num_total_classes); dist_probs->numel(), num_total_classes));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_alias->numel(), num_total_classes, dist_alias->numel(), num_total_classes,
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistAlias) " "ShapeError: The number of elements in Input(CustomDistAlias) "
"should be equal to the number of total classes. But Received: " "should be equal to the number of total classes. But Received: "
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) " "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
"= %d.", "= %d.",
dist_alias->numel(), num_total_classes); dist_alias->numel(), num_total_classes));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_alias_probs->numel(), num_total_classes, dist_alias_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAliasProbs) " platform::errors::InvalidArgument(
"ShapeError: The number of elements in "
"Input(CustomDistAliasProbs) "
"should be equal to the number of total classes. But Received: " "should be equal to the number of total classes. But Received: "
"Input(CustomDistAliasProbs).numel() = %d, " "Input(CustomDistAliasProbs).numel() = %d, "
"Attr(num_total_classes) = %d.", "Attr(num_total_classes) = %d.",
dist_alias_probs->numel(), num_total_classes); dist_alias_probs->numel(), num_total_classes));
const float *probs_data = dist_probs->data<float>(); const float *probs_data = dist_probs->data<float>();
const int *alias_data = dist_alias->data<int>(); const int *alias_data = dist_alias->data<int>();
......
...@@ -37,13 +37,15 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { ...@@ -37,13 +37,15 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
if (ctx->HasInput("AccumulatePositivePair") || if (ctx->HasInput("AccumulatePositivePair") ||
ctx->HasInput("AccumulateNegativePair") || ctx->HasInput("AccumulateNegativePair") ||
ctx->HasInput("AccumulateNeutralPair")) { ctx->HasInput("AccumulateNeutralPair")) {
PADDLE_ENFORCE(ctx->HasInput("AccumulatePositivePair") && PADDLE_ENFORCE_EQ(
ctx->HasInput("AccumulatePositivePair") &&
ctx->HasInput("AccumulateNegativePair") && ctx->HasInput("AccumulateNegativePair") &&
ctx->HasInput("AccumulateNeutralPair"), ctx->HasInput("AccumulateNeutralPair"),
true, platform::errors::InvalidArgument(
"All optional inputs(AccumulatePositivePair, " "All optional inputs(AccumulatePositivePair, "
"AccumulateNegativePair, AccumulateNeutralPair) of " "AccumulateNegativePair, AccumulateNeutralPair) of "
"PositiveNegativePairOp are required if one of them is " "PositiveNegativePairOp are required if one of them "
"specified."); "is specified."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->GetInputDim("AccumulatePositivePair"), scalar_dim, ctx->GetInputDim("AccumulatePositivePair"), scalar_dim,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册