From 1d95a0fbc382d4e3650419e72a8818bc48c8f651 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Tue, 13 Oct 2020 10:52:36 +0800 Subject: [PATCH] fix error message for nce_op (#27863) --- paddle/fluid/operators/nce_op.h | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 8748078109..3357db8454 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -135,7 +135,12 @@ class NCEKernel : public framework::OpKernel { alias_data, alias_probs_data, seed); break; } - default: { PADDLE_THROW("Unsupported SamplerType."); } + default: { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported SamplerType. SamplerType should be 0: Uniform, " + "1: LogUniform or 2: CostumDist. Received SamplerType: %d", + sampler_type)); + } } PrepareSamples(context, sampler); @@ -225,9 +230,9 @@ class NCEKernel : public framework::OpKernel { weight, false, table_names, epmap, context, local_scope); #else - PADDLE_THROW( + PADDLE_THROW(platform::errors::PreconditionNotMet( "paddle is not compiled with distribute support, can not do " - "parameter prefetch!"); + "parameter prefetch!")); #endif auto weight_mat = EigenMatrix::From( @@ -347,7 +352,12 @@ class NCEGradKernel : public framework::OpKernel { alias_data, alias_probs_data, seed); break; } - default: { PADDLE_THROW("Unsupported SamplerType."); } + default: { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported SamplerType. SamplerType should be 0: Uniform, " + "1: LogUniform or 2: CostumDist. Received SamplerType: %d", + sampler_type)); + } } // T b = 1. / num_total_classes * num_neg_samples; @@ -409,9 +419,9 @@ class NCEGradKernel : public framework::OpKernel { auto *table_t = context.Input("Weight"); table_dim = table_t->value().dims(); } else { - PADDLE_THROW( + PADDLE_THROW(platform::errors::InvalidArgument( "The parameter Weight of a NCE_OP " - "must be either LoDTensor or SelectedRows"); + "must be either LoDTensor or SelectedRows")); } auto d_w = context.Output(framework::GradVarName("Weight")); -- GitLab