提交 c2063217 编写于 作者: Z Zhang Ting 提交者: Tao Luo

optimize error message for "embedding" and "cross_entropy" OP (#18765)

* optimize error message, test=develop

* optimize error message, test=develop
上级 741ce8bb
......@@ -156,7 +156,10 @@ struct HardLabelCrossEntropyForwardFunctor {
auto label = label_[idx];
if (label != ignore_index_) {
PADDLE_ASSERT_MSG(label >= 0 && label < feature_size_,
"The label is out of the range.", label);
"Variable value (label) of "
"OP(fluid.layers.cross_entropy) expected >= 0 "
"and < %ld, but got %ld. Please check label value.",
feature_size_, label);
auto match_x = x_[idx * feature_size_ + label];
y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
match_x_[idx] = match_x;
......
......@@ -32,8 +32,16 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids,
while (idy < K) {
int64_t id = ids[idy];
PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG(id < N, "received id:", id);
PADDLE_ASSERT_MSG(
id >= 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, id);
PADDLE_ASSERT_MSG(
id < N,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, id);
T *out = output + idy * D;
const T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
......@@ -59,8 +67,16 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
while (idy < K) {
int64_t id = ids[idy];
PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG(id < N, "received id:", id);
PADDLE_ASSERT_MSG(
id >= 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, id);
PADDLE_ASSERT_MSG(
id < N,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, id);
const T *out = output + idy * D;
T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
......
......@@ -85,8 +85,18 @@ class LookupTableKernel : public framework::OpKernel<T> {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
PADDLE_ENFORCE_LT(
ids[i], row_number,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]);
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
......@@ -181,8 +191,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
auto *ids_data = ids->data<int64_t>();
int N = table_dim[0];
int D = table_dim[1];
int64_t N = table_dim[0];
int64_t D = table_dim[1];
auto *d_output_data = d_output->data<T>();
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
......@@ -194,8 +204,16 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
// the gradient of padding_idx should be 0, already done by memset, so
// do nothing.
} else {
PADDLE_ENFORCE_LT(ids_data[i], N);
PADDLE_ENFORCE_GE(ids_data[i], 0);
PADDLE_ENFORCE_LT(
ids_data[i], N,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, ids_data[i]);
PADDLE_ENFORCE_GE(
ids_data[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, ids_data[i]);
for (int j = 0; j < D; ++j) {
d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
}
......
......@@ -38,11 +38,11 @@ limitations under the License. */
} while (0)
// NOTE: PADDLE_ASSERT is mainly used in CUDA Kernel or HOSTDEVICE function.
#define PADDLE_ASSERT_MSG(_IS_NOT_ERROR, __MSG, __VAL) \
do { \
if (!(_IS_NOT_ERROR)) { \
printf("Exception: %s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, \
__LINE__, TOSTRING(_IS_NOT_ERROR), __MSG, __VAL); \
EXIT(); \
} \
#define PADDLE_ASSERT_MSG(_IS_NOT_ERROR, __FORMAT, ...) \
do { \
if (!(_IS_NOT_ERROR)) { \
printf("Exception: %s:%d Assertion `%s` failed. " __FORMAT "\n", \
__FILE__, __LINE__, TOSTRING(_IS_NOT_ERROR), ##__VA_ARGS__); \
EXIT(); \
} \
} while (0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册