diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h old mode 100644 new mode 100755 index 309ba46cfa3b35fd4f6a4a889965b717b890a303..2f136784f6b4758175ac38dd003ed4f068dd4bcd --- a/paddle/fluid/operators/cross_entropy_op.h +++ b/paddle/fluid/operators/cross_entropy_op.h @@ -156,7 +156,10 @@ struct HardLabelCrossEntropyForwardFunctor { auto label = label_[idx]; if (label != ignore_index_) { PADDLE_ASSERT_MSG(label >= 0 && label < feature_size_, - "The label is out of the range.", label); + "Variable value (label) of " + "OP(fluid.layers.cross_entropy) expected >= 0 " + "and < %ld, but got %ld. Please check label value.", + feature_size_, label); auto match_x = x_[idx * feature_size_ + label]; y_[idx] = -math::TolerableValue()(real_log(match_x)); match_x_[idx] = match_x; diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 8716662f158bd939755feda71e0ac8ea5748ac26..cb432e6d3e91bfeff725e64a13909077977bdb11 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -32,8 +32,16 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids, while (idy < K) { int64_t id = ids[idy]; - PADDLE_ASSERT_MSG(id >= 0, "received id:", id); - PADDLE_ASSERT_MSG(id < N, "received id:", id); + PADDLE_ASSERT_MSG( + id >= 0, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, id); + PADDLE_ASSERT_MSG( + id < N, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, id); T *out = output + idy * D; const T *tab = table + id * D; for (int i = idx; i < D; i += BlockDimX) { @@ -59,8 +67,16 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids, while (idy < K) { int64_t id = ids[idy]; - PADDLE_ASSERT_MSG(id >= 0, "received id:", id); - PADDLE_ASSERT_MSG(id < N, "received id:", id); + PADDLE_ASSERT_MSG( + id >= 0, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, id); + PADDLE_ASSERT_MSG( + id < N, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, id); const T *out = output + idy * D; T *tab = table + id * D; for (int i = idx; i < D; i += BlockDimX) { diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 62e298e066948c93a84a131a0dffc0a1d53f2a5b..b3e48638c6c0bacac32895c6da1cfe7597a28744 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -85,8 +85,18 @@ class LookupTableKernel : public framework::OpKernel { if (padding_idx != kNoPadding && ids[i] == padding_idx) { memset(output + i * row_width, 0, row_width * sizeof(T)); } else { - PADDLE_ENFORCE_LT(ids[i], row_number); - PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i); + PADDLE_ENFORCE_LT( + ids[i], row_number, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + row_number, ids[i]); + PADDLE_ENFORCE_GE( + ids[i], 0, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + row_number, ids[i]); memcpy(output + i * row_width, table + ids[i] * row_width, row_width * sizeof(T)); } @@ -181,8 +191,8 @@ class LookupTableGradKernel : public framework::OpKernel { auto *ids_data = ids->data(); - int N = table_dim[0]; - int D = table_dim[1]; + int64_t N = table_dim[0]; + int64_t D = table_dim[1]; auto *d_output_data = d_output->data(); auto *d_table_data = d_table->mutable_data(context.GetPlace()); @@ -194,8 +204,16 @@ class LookupTableGradKernel : public framework::OpKernel { // the gradient of padding_idx should be 0, already done by memset, so // do nothing. } else { - PADDLE_ENFORCE_LT(ids_data[i], N); - PADDLE_ENFORCE_GE(ids_data[i], 0); + PADDLE_ENFORCE_LT( + ids_data[i], N, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, ids_data[i]); + PADDLE_ENFORCE_GE( + ids_data[i], 0, + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input value.", + N, ids_data[i]); for (int j = 0; j < D; ++j) { d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j]; } diff --git a/paddle/fluid/platform/assert.h b/paddle/fluid/platform/assert.h index e3884a985e08ad94fc95cfa65329f848e0715bd1..83c08b8266a08649287b12364fc27b0cedb0e695 100644 --- a/paddle/fluid/platform/assert.h +++ b/paddle/fluid/platform/assert.h @@ -38,11 +38,11 @@ limitations under the License. */ } while (0) // NOTE: PADDLE_ASSERT is mainly used in CUDA Kernel or HOSTDEVICE function. -#define PADDLE_ASSERT_MSG(_IS_NOT_ERROR, __MSG, __VAL) \ - do { \ - if (!(_IS_NOT_ERROR)) { \ - printf("Exception: %s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, \ - __LINE__, TOSTRING(_IS_NOT_ERROR), __MSG, __VAL); \ - EXIT(); \ - } \ +#define PADDLE_ASSERT_MSG(_IS_NOT_ERROR, __FORMAT, ...) \ + do { \ + if (!(_IS_NOT_ERROR)) { \ + printf("Exception: %s:%d Assertion `%s` failed. " __FORMAT "\n", \ + __FILE__, __LINE__, TOSTRING(_IS_NOT_ERROR), ##__VA_ARGS__); \ + EXIT(); \ + } \ } while (0)