未验证 提交 1d95a0fb 编写于 作者: F Feiyu Chan 提交者: GitHub

fix error message for nce_op (#27863)

上级 4237fefe
...@@ -135,7 +135,12 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -135,7 +135,12 @@ class NCEKernel : public framework::OpKernel<T> {
alias_data, alias_probs_data, seed); alias_data, alias_probs_data, seed);
break; break;
} }
default: { PADDLE_THROW("Unsupported SamplerType."); } default: {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported SamplerType. SamplerType should be 0: Uniform, "
"1: LogUniform or 2: CostumDist. Received SamplerType: %d",
sampler_type));
}
} }
PrepareSamples<DeviceContext, T>(context, sampler); PrepareSamples<DeviceContext, T>(context, sampler);
...@@ -225,9 +230,9 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -225,9 +230,9 @@ class NCEKernel : public framework::OpKernel<T> {
weight, false, table_names, epmap, weight, false, table_names, epmap,
context, local_scope); context, local_scope);
#else #else
PADDLE_THROW( PADDLE_THROW(platform::errors::PreconditionNotMet(
"paddle is not compiled with distribute support, can not do " "paddle is not compiled with distribute support, can not do "
"parameter prefetch!"); "parameter prefetch!"));
#endif #endif
auto weight_mat = EigenMatrix<T>::From( auto weight_mat = EigenMatrix<T>::From(
...@@ -347,7 +352,12 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -347,7 +352,12 @@ class NCEGradKernel : public framework::OpKernel<T> {
alias_data, alias_probs_data, seed); alias_data, alias_probs_data, seed);
break; break;
} }
default: { PADDLE_THROW("Unsupported SamplerType."); } default: {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported SamplerType. SamplerType should be 0: Uniform, "
"1: LogUniform or 2: CostumDist. Received SamplerType: %d",
sampler_type));
}
} }
// T b = 1. / num_total_classes * num_neg_samples; // T b = 1. / num_total_classes * num_neg_samples;
...@@ -409,9 +419,9 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -409,9 +419,9 @@ class NCEGradKernel : public framework::OpKernel<T> {
auto *table_t = context.Input<SelectedRows>("Weight"); auto *table_t = context.Input<SelectedRows>("Weight");
table_dim = table_t->value().dims(); table_dim = table_t->value().dims();
} else { } else {
PADDLE_THROW( PADDLE_THROW(platform::errors::InvalidArgument(
"The parameter Weight of a NCE_OP " "The parameter Weight of a NCE_OP "
"must be either LoDTensor or SelectedRows"); "must be either LoDTensor or SelectedRows"));
} }
auto d_w = context.Output<SelectedRows>(framework::GradVarName("Weight")); auto d_w = context.Output<SelectedRows>(framework::GradVarName("Weight"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册