diff --git a/paddle/fluid/operators/bpr_loss_op.h b/paddle/fluid/operators/bpr_loss_op.h index f9570e4e2ed0d9ac8739410eb7cd7397ad09fae4..a01666596b62cd0f8433e6bc290ed92ba77966ad 100644 --- a/paddle/fluid/operators/bpr_loss_op.h +++ b/paddle/fluid/operators/bpr_loss_op.h @@ -28,7 +28,7 @@ using Tensor = framework::Tensor; template struct TolerableValue { HOSTDEVICE T operator()(const T& x) const { - PADDLE_ASSERT(std::is_floating_point::value); + PADDLE_ENFORCE_EQ(std::is_floating_point::value, true); const T kApproInf = 1e20; if (x == INFINITY) return kApproInf; if (x == -INFINITY) return -kApproInf; diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index 5bc05257aa9d3db7881330ca4547da439dab03bd..59f4485aa92c8dbaf219369ae0e0406758462920 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -27,7 +27,10 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, const int ignore_index) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { - PADDLE_ASSERT(label[i] >= 0 && label[i] < D || label[i] == ignore_index); + PADDLE_ASSERT_MSG(label[i] >= 0 && label[i] < D || label[i] == ignore_index, + "label[%d] expected >= 0 and < %ld, or == %ld, but got " + "%ld. Please check input value.", + i, D, ignore_index, label[i]); Y[i] = ignore_index == label[i] ? static_cast(0) : -math::TolerableValue()(real_log(X[i * D + label[i]])); diff --git a/paddle/fluid/operators/math/cross_entropy.h b/paddle/fluid/operators/math/cross_entropy.h index 48082a7273dd7ad713fbc964ebbd1445ed887cdd..23d2cf4fd9f532f9217b63e84d928fce5e8e0acb 100644 --- a/paddle/fluid/operators/math/cross_entropy.h +++ b/paddle/fluid/operators/math/cross_entropy.h @@ -25,7 +25,8 @@ namespace math { template struct TolerableValue { HOSTDEVICE T operator()(const T& x) const { - PADDLE_ASSERT(std::is_floating_point::value); + PADDLE_ASSERT_MSG(std::is_floating_point::value, + "TolerableValue should be float in cross_entropy."); const T kApproInf = 1e20; if (x == INFINITY) return kApproInf; diff --git a/paddle/fluid/operators/math/unpooling.cu b/paddle/fluid/operators/math/unpooling.cu index c467ae8427d8f461b332eed8075631ed7e47b96e..de6ee7c7cd6e8305af9386bb3d30d19c9846b690 100644 --- a/paddle/fluid/operators/math/unpooling.cu +++ b/paddle/fluid/operators/math/unpooling.cu @@ -37,7 +37,10 @@ __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data, int cidx = boffset / in_c_stride; int out_offset = bidx * out_n_stride + cidx * out_c_stride; int out_index = indices_data[i]; - PADDLE_ASSERT(out_index < out_c_stride); + PADDLE_ASSERT_MSG(out_index < out_c_stride, + "out_index < out_c_stride. Expected %ld < %ld, but got " + "%ld >= %ld. Please check input value.", + out_index, out_c_stride, out_index, out_c_stride); output_data[out_offset + out_index] = input_data[i]; } } @@ -59,7 +62,10 @@ __global__ void KernelUnpool2dMaxGrad( int cidx = boffset / in_c_stride; int out_offset = bidx * out_n_stride + cidx * out_c_stride; int out_index = indices_data[i]; - PADDLE_ASSERT(out_index < out_c_stride); + PADDLE_ASSERT_MSG(out_index < out_c_stride, + "out_index < out_c_stride. Expected %ld < %ld, but got " + "%ld >= %ld. Please check input value.", + out_index, out_c_stride, out_index, out_c_stride); input_grad[i] = output_grad[out_offset + out_index]; } } diff --git a/paddle/fluid/operators/modified_huber_loss_op.h b/paddle/fluid/operators/modified_huber_loss_op.h index d2b6d0c4bab1619f10f68bd9bf22f975c4c2dfd7..d7dbf791a7ee13d87836bb6b0292a44eafa982d9 100644 --- a/paddle/fluid/operators/modified_huber_loss_op.h +++ b/paddle/fluid/operators/modified_huber_loss_op.h @@ -29,7 +29,10 @@ using EigenVector = framework::EigenVector; template struct CheckLabelValue { HOSTDEVICE T operator()(const T& val) const { - PADDLE_ASSERT(val == static_cast(0) || val == static_cast(1)); + PADDLE_ASSERT_MSG(val == static_cast(0) || val == static_cast(1), + "LabelValue of modified_huber_loss_op expected to be 0 " + "or 1, but got %ld. Please check input value.", + val); } }; diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h index ee034b270527376fc268b8a868f90db52c51848a..e1457eccb5b4d15941ad135e8a20a89ddddd26d8 100644 --- a/paddle/fluid/operators/random_crop_op.h +++ b/paddle/fluid/operators/random_crop_op.h @@ -60,7 +60,16 @@ HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out, size_t offset_i = offsets[i]; if (i == rank - 1) { - PADDLE_ASSERT(x_stride == 1 && out_stride == 1); + PADDLE_ASSERT_MSG(x_stride == 1, + "When i:%d == rank:%d - 1, x_stride of random_crop_op " + "expected to be 1, but got %ld. Please check input " + "value.", + i, rank, x_stride); + PADDLE_ASSERT_MSG(out_stride == 1, + "When i:%d == rank:%d - 1, out_stride of random_crop_op " + "expected to be 1, but got %ld. Please check input " + "value.", + i, rank, out_stride); x += offset_i; for (size_t j = 0; j < out_dim_i; ++j) { *out++ = *x++; diff --git a/paddle/fluid/operators/sample_logits_op.h b/paddle/fluid/operators/sample_logits_op.h index b55a24863cc09d5f80e07aedbbb5b3d9ac99e69e..7e78fca714de6cd8a982030b13e89bab0039cc19 100644 --- a/paddle/fluid/operators/sample_logits_op.h +++ b/paddle/fluid/operators/sample_logits_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -33,7 +34,8 @@ using EigenMatrix = framework::EigenMatrix; template struct TolerableValue { HOSTDEVICE T operator()(const T& x) const { - PADDLE_ASSERT(std::is_floating_point::value); + PADDLE_ASSERT_MSG(std::is_floating_point::value, + "TolerableValue should be float in sample_logits_op."); const T kApproInf = 1e20; if (x == INFINITY) return kApproInf; if (x == -INFINITY) return -kApproInf; diff --git a/paddle/fluid/platform/assert.h b/paddle/fluid/platform/assert.h index 83c08b8266a08649287b12364fc27b0cedb0e695..2883bd5ed34834692cb0b637da372cb8e343d9bf 100644 --- a/paddle/fluid/platform/assert.h +++ b/paddle/fluid/platform/assert.h @@ -28,15 +28,6 @@ limitations under the License. */ #define EXIT() throw std::runtime_error("Exception encounter.") #endif -#define PADDLE_ASSERT(_IS_NOT_ERROR) \ - do { \ - if (!(_IS_NOT_ERROR)) { \ - printf("Exception: %s:%d Assertion `%s` failed.\n", __FILE__, __LINE__, \ - TOSTRING(_IS_NOT_ERROR)); \ - EXIT(); \ - } \ - } while (0) - // NOTE: PADDLE_ASSERT is mainly used in CUDA Kernel or HOSTDEVICE function. #define PADDLE_ASSERT_MSG(_IS_NOT_ERROR, __FORMAT, ...) \ do { \ diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 3021b8e7ef3bfbcf89e241ee8ea00d03a78d4ca7..d656c12d8f8640bcd7ea50bcf8c1aa3fa0852931 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -48,7 +48,6 @@ if(NOT WITH_GPU OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_pipeline) endif() list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290 -list(REMOVE_ITEM TEST_OPS test_modified_huber_loss_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184 list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185 list(REMOVE_ITEM TEST_OPS test_cond_op) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957