未验证 提交 0a7bab4e 编写于 作者: F Feiyu Chan 提交者: GitHub

fix error mesage for negative_positive_pair_op and nce_op (#27779)

上级 395cb561
...@@ -104,25 +104,29 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -104,25 +104,29 @@ class NCEKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_probs->numel(), num_total_classes, dist_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistProbs) " platform::errors::InvalidArgument(
"should be equal to the number of total classes. But Received: " "ShapeError: The number of elements in Input(CustomDistProbs) "
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) " "should be equal to the number of total classes. But Received: "
"= %d.", "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
dist_probs->numel(), num_total_classes); "= %d.",
dist_probs->numel(), num_total_classes));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_alias->numel(), num_total_classes, dist_alias->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAlias) " platform::errors::InvalidArgument(
"should be equal to the number of total classes. But Received: " "ShapeError: The number of elements in Input(CustomDistAlias) "
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) " "should be equal to the number of total classes. But Received: "
"= %d.", "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
dist_alias->numel(), num_total_classes); "= %d.",
dist_alias->numel(), num_total_classes));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_alias_probs->numel(), num_total_classes, dist_alias_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAliasProbs) " platform::errors::InvalidArgument(
"should be equal to the number of total classes. But Received: " "ShapeError: The number of elements in "
"Input(CustomDistAliasProbs).numel() = %d, " "Input(CustomDistAliasProbs) "
"Attr(num_total_classes) = %d.", "should be equal to the number of total classes. But Received: "
dist_alias_probs->numel(), num_total_classes); "Input(CustomDistAliasProbs).numel() = %d, "
"Attr(num_total_classes) = %d.",
dist_alias_probs->numel(), num_total_classes));
const float *probs_data = dist_probs->data<float>(); const float *probs_data = dist_probs->data<float>();
const int *alias_data = dist_alias->data<int>(); const int *alias_data = dist_alias->data<int>();
...@@ -140,10 +144,11 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -140,10 +144,11 @@ class NCEKernel : public framework::OpKernel<T> {
for (int x = 0; x < sample_labels->numel(); x++) { for (int x = 0; x < sample_labels->numel(); x++) {
PADDLE_ENFORCE_GE(sample_labels_data[x], 0, PADDLE_ENFORCE_GE(sample_labels_data[x], 0,
"ValueError: Every sample label should be " platform::errors::InvalidArgument(
"non-negative. But received: " "ValueError: Every sample label should be "
"Input(SampleLabels)[%d] = %d", "non-negative. But received: "
x, sample_labels_data[x]); "Input(SampleLabels)[%d] = %d",
x, sample_labels_data[x]));
} }
auto sample_out = context.Output<Tensor>("SampleLogits"); auto sample_out = context.Output<Tensor>("SampleLogits");
...@@ -311,25 +316,29 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -311,25 +316,29 @@ class NCEGradKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_probs->numel(), num_total_classes, dist_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistProbs) " platform::errors::InvalidArgument(
"should be equal to the number of total classes. But Received: " "ShapeError: The number of elements in Input(CustomDistProbs) "
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) " "should be equal to the number of total classes. But Received: "
"= %d.", "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
dist_probs->numel(), num_total_classes); "= %d.",
dist_probs->numel(), num_total_classes));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_alias->numel(), num_total_classes, dist_alias->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAlias) " platform::errors::InvalidArgument(
"should be equal to the number of total classes. But Received: " "ShapeError: The number of elements in Input(CustomDistAlias) "
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) " "should be equal to the number of total classes. But Received: "
"= %d.", "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
dist_alias->numel(), num_total_classes); "= %d.",
dist_alias->numel(), num_total_classes));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dist_alias_probs->numel(), num_total_classes, dist_alias_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAliasProbs) " platform::errors::InvalidArgument(
"should be equal to the number of total classes. But Received: " "ShapeError: The number of elements in "
"Input(CustomDistAliasProbs).numel() = %d, " "Input(CustomDistAliasProbs) "
"Attr(num_total_classes) = %d.", "should be equal to the number of total classes. But Received: "
dist_alias_probs->numel(), num_total_classes); "Input(CustomDistAliasProbs).numel() = %d, "
"Attr(num_total_classes) = %d.",
dist_alias_probs->numel(), num_total_classes));
const float *probs_data = dist_probs->data<float>(); const float *probs_data = dist_probs->data<float>();
const int *alias_data = dist_alias->data<int>(); const int *alias_data = dist_alias->data<int>();
......
...@@ -37,13 +37,15 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { ...@@ -37,13 +37,15 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
if (ctx->HasInput("AccumulatePositivePair") || if (ctx->HasInput("AccumulatePositivePair") ||
ctx->HasInput("AccumulateNegativePair") || ctx->HasInput("AccumulateNegativePair") ||
ctx->HasInput("AccumulateNeutralPair")) { ctx->HasInput("AccumulateNeutralPair")) {
PADDLE_ENFORCE(ctx->HasInput("AccumulatePositivePair") && PADDLE_ENFORCE_EQ(
ctx->HasInput("AccumulateNegativePair") && ctx->HasInput("AccumulatePositivePair") &&
ctx->HasInput("AccumulateNeutralPair"), ctx->HasInput("AccumulateNegativePair") &&
"All optional inputs(AccumulatePositivePair, " ctx->HasInput("AccumulateNeutralPair"),
"AccumulateNegativePair, AccumulateNeutralPair) of " true, platform::errors::InvalidArgument(
"PositiveNegativePairOp are required if one of them is " "All optional inputs(AccumulatePositivePair, "
"specified."); "AccumulateNegativePair, AccumulateNeutralPair) of "
"PositiveNegativePairOp are required if one of them "
"is specified."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->GetInputDim("AccumulatePositivePair"), scalar_dim, ctx->GetInputDim("AccumulatePositivePair"), scalar_dim,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册