diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 1c75424fae7ef3efe3720de7d8e0303661d805ca..8748078109f16eaf02bd9b09af5d1ec993e74e9e 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -104,25 +104,29 @@ class NCEKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ( dist_probs->numel(), num_total_classes, - "ShapeError: The number of elements in Input(CustomDistProbs) " - "should be equal to the number of total classes. But Received: " - "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) " - "= %d.", - dist_probs->numel(), num_total_classes); + platform::errors::InvalidArgument( + "ShapeError: The number of elements in Input(CustomDistProbs) " + "should be equal to the number of total classes. But Received: " + "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) " + "= %d.", + dist_probs->numel(), num_total_classes)); PADDLE_ENFORCE_EQ( dist_alias->numel(), num_total_classes, - "ShapeError: The number of elements in Input(CustomDistAlias) " - "should be equal to the number of total classes. But Received: " - "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) " - "= %d.", - dist_alias->numel(), num_total_classes); + platform::errors::InvalidArgument( + "ShapeError: The number of elements in Input(CustomDistAlias) " + "should be equal to the number of total classes. But Received: " + "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) " + "= %d.", + dist_alias->numel(), num_total_classes)); PADDLE_ENFORCE_EQ( dist_alias_probs->numel(), num_total_classes, - "ShapeError: The number of elements in Input(CustomDistAliasProbs) " - "should be equal to the number of total classes. But Received: " - "Input(CustomDistAliasProbs).numel() = %d, " - "Attr(num_total_classes) = %d.", - dist_alias_probs->numel(), num_total_classes); + platform::errors::InvalidArgument( + "ShapeError: The number of elements in " + "Input(CustomDistAliasProbs) " + "should be equal to the number of total classes. But Received: " + "Input(CustomDistAliasProbs).numel() = %d, " + "Attr(num_total_classes) = %d.", + dist_alias_probs->numel(), num_total_classes)); const float *probs_data = dist_probs->data(); const int *alias_data = dist_alias->data(); @@ -140,10 +144,11 @@ class NCEKernel : public framework::OpKernel { for (int x = 0; x < sample_labels->numel(); x++) { PADDLE_ENFORCE_GE(sample_labels_data[x], 0, - "ValueError: Every sample label should be " - "non-negative. But received: " - "Input(SampleLabels)[%d] = %d", - x, sample_labels_data[x]); + platform::errors::InvalidArgument( + "ValueError: Every sample label should be " + "non-negative. But received: " + "Input(SampleLabels)[%d] = %d", + x, sample_labels_data[x])); } auto sample_out = context.Output("SampleLogits"); @@ -311,25 +316,29 @@ class NCEGradKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ( dist_probs->numel(), num_total_classes, - "ShapeError: The number of elements in Input(CustomDistProbs) " - "should be equal to the number of total classes. But Received: " - "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) " - "= %d.", - dist_probs->numel(), num_total_classes); + platform::errors::InvalidArgument( + "ShapeError: The number of elements in Input(CustomDistProbs) " + "should be equal to the number of total classes. But Received: " + "Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) " + "= %d.", + dist_probs->numel(), num_total_classes)); PADDLE_ENFORCE_EQ( dist_alias->numel(), num_total_classes, - "ShapeError: The number of elements in Input(CustomDistAlias) " - "should be equal to the number of total classes. But Received: " - "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) " - "= %d.", - dist_alias->numel(), num_total_classes); + platform::errors::InvalidArgument( + "ShapeError: The number of elements in Input(CustomDistAlias) " + "should be equal to the number of total classes. But Received: " + "Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) " + "= %d.", + dist_alias->numel(), num_total_classes)); PADDLE_ENFORCE_EQ( dist_alias_probs->numel(), num_total_classes, - "ShapeError: The number of elements in Input(CustomDistAliasProbs) " - "should be equal to the number of total classes. But Received: " - "Input(CustomDistAliasProbs).numel() = %d, " - "Attr(num_total_classes) = %d.", - dist_alias_probs->numel(), num_total_classes); + platform::errors::InvalidArgument( + "ShapeError: The number of elements in " + "Input(CustomDistAliasProbs) " + "should be equal to the number of total classes. But Received: " + "Input(CustomDistAliasProbs).numel() = %d, " + "Attr(num_total_classes) = %d.", + dist_alias_probs->numel(), num_total_classes)); const float *probs_data = dist_probs->data(); const int *alias_data = dist_alias->data(); diff --git a/paddle/fluid/operators/positive_negative_pair_op.cc b/paddle/fluid/operators/positive_negative_pair_op.cc index e42c4666e110f5ad86272cccf136a570b8dce100..75d1b36c7d6a8dc71042566430ecce7b22b06a36 100644 --- a/paddle/fluid/operators/positive_negative_pair_op.cc +++ b/paddle/fluid/operators/positive_negative_pair_op.cc @@ -37,13 +37,15 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { if (ctx->HasInput("AccumulatePositivePair") || ctx->HasInput("AccumulateNegativePair") || ctx->HasInput("AccumulateNeutralPair")) { - PADDLE_ENFORCE(ctx->HasInput("AccumulatePositivePair") && - ctx->HasInput("AccumulateNegativePair") && - ctx->HasInput("AccumulateNeutralPair"), - "All optional inputs(AccumulatePositivePair, " - "AccumulateNegativePair, AccumulateNeutralPair) of " - "PositiveNegativePairOp are required if one of them is " - "specified."); + PADDLE_ENFORCE_EQ( + ctx->HasInput("AccumulatePositivePair") && + ctx->HasInput("AccumulateNegativePair") && + ctx->HasInput("AccumulateNeutralPair"), + true, platform::errors::InvalidArgument( + "All optional inputs(AccumulatePositivePair, " + "AccumulateNegativePair, AccumulateNeutralPair) of " + "PositiveNegativePairOp are required if one of them " + "is specified.")); PADDLE_ENFORCE_EQ( ctx->GetInputDim("AccumulatePositivePair"), scalar_dim, platform::errors::InvalidArgument(