Unverified commit 3ae939e4, authored by Tao Luo, committed by GitHub

unify PADDLE_ASSERT_MSG into PADDLE_ENFORCE(error_message) (#19631)

* remove assert.h

* change PADDLE_ASSERT_MSG to PADDLE_ENFORCE

test=develop

* fix tensorrt paddle_enforce

test=develop
Parent af692c91
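Illustrative sketch (not part of the diff): after this commit, device-side range checks that previously went through PADDLE_ASSERT_MSG use PADDLE_ENFORCE with a printf-style message. The kernel name and parameters below are hypothetical; only the macro and the message pattern come from the hunks that follow.

#include "paddle/fluid/platform/enforce.h"

// Hypothetical CUDA kernel showing the unified macro in device code.
__global__ void ExampleKernel(const int64_t* ids, int64_t N, int64_t K) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < K) {
    int64_t id = ids[idx];
    // Under __CUDA_ARCH__ this expands to printf(...) followed by asm("trap;");
    // on the host it goes through the usual enforce error path.
    PADDLE_ENFORCE(id >= 0 && id < N,
                   "id expected >= 0 and < %ld, but got %ld.", N, id);
  }
}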
@@ -20,7 +20,6 @@
#include <type_traits>
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h"
@@ -111,10 +111,9 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, h_odatas,
d_output_ptrs_.size() * sizeof(float*),
cudaMemcpyHostToDevice,
stream) == cudaSuccess);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(
output_ptrs, h_odatas, d_output_ptrs_.size() * sizeof(float*),
cudaMemcpyHostToDevice, stream));
int outer_rows = outer_rows_ * batchSize;
@@ -14,7 +14,6 @@ limitations under the License. */
#include <iostream>
#include "paddle/fluid/operators/center_loss_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
@@ -31,8 +30,8 @@ __global__ void ComputeDifferent(T *centers_diff, const T *X, const T *centers,
while (idy < K) {
int64_t id = ids[idy];
PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG(id < N, "received id:", id);
PADDLE_ENFORCE(id >= 0, "received id:", id);
PADDLE_ENFORCE(id < N, "received id:", id);
T *out = centers_diff + idy * D;
const T *x = X + idy * D;
const T *cent = centers + id * D;
@@ -53,8 +52,8 @@ __global__ void UpdateCenters(T *centers, T *centers_diff, const int64_t *ids,
while (idy < K) {
int count = 1;
int64_t id = ids[idy];
PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG(id < N, "received id:", id);
PADDLE_ENFORCE(id >= 0, "received id:", id);
PADDLE_ENFORCE(id < N, "received id:", id);
for (int i = 0; i < K; i++) {
if (ids[i] == id) {
@@ -155,11 +155,11 @@ struct HardLabelCrossEntropyForwardFunctor {
HOSTDEVICE void operator()(int64_t idx) const {
auto label = label_[idx];
if (label != ignore_index_) {
PADDLE_ASSERT_MSG(label >= 0 && label < feature_size_,
"Variable value (label) of "
"OP(fluid.layers.cross_entropy) expected >= 0 "
"and < %ld, but got %ld. Please check label value.",
feature_size_, label);
PADDLE_ENFORCE(label >= 0 && label < feature_size_,
"Variable value (label) of "
"OP(fluid.layers.cross_entropy) expected >= 0 "
"and < %ld, but got %ld. Please check label value.",
feature_size_, label);
auto match_x = x_[idx * feature_size_ + label];
y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
match_x_[idx] = match_x;
@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lookup_table_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
@@ -32,12 +31,12 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids,
while (idy < K) {
int64_t id = ids[idy];
PADDLE_ASSERT_MSG(
PADDLE_ENFORCE(
id >= 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, id);
PADDLE_ASSERT_MSG(
PADDLE_ENFORCE(
id < N,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
@@ -67,12 +66,12 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
while (idy < K) {
int64_t id = ids[idy];
PADDLE_ASSERT_MSG(
PADDLE_ENFORCE(
id >= 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
N, id);
PADDLE_ASSERT_MSG(
PADDLE_ENFORCE(
id < N,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input value.",
@@ -7,7 +7,7 @@ function(math_library TARGET)
set(cc_srcs)
set(cu_srcs)
set(hip_srcs)
set(math_common_deps device_context framework_proto)
set(math_common_deps device_context framework_proto enforce)
set(multiValueArgs DEPS)
cmake_parse_arguments(math_library "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
@@ -27,10 +27,10 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
const int ignore_index) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT_MSG(label[i] >= 0 && label[i] < D || label[i] == ignore_index,
"label[%d] expected >= 0 and < %ld, or == %ld, but got "
"%ld. Please check input value.",
i, D, ignore_index, label[i]);
PADDLE_ENFORCE(label[i] >= 0 && label[i] < D || label[i] == ignore_index,
"label[%d] expected >= 0 and < %ld, or == %ld, but got "
"%ld. Please check input value.",
i, D, ignore_index, label[i]);
Y[i] = ignore_index == label[i]
? static_cast<T>(0)
: -math::TolerableValue<T>()(real_log(X[i * D + label[i]]));
@@ -25,8 +25,8 @@ namespace math {
template <typename T>
struct TolerableValue {
HOSTDEVICE T operator()(const T& x) const {
PADDLE_ASSERT_MSG(std::is_floating_point<T>::value,
"TolerableValue should be float in cross_entropy.");
PADDLE_ENFORCE(std::is_floating_point<T>::value,
"TolerableValue should be float in cross_entropy.");
const T kApproInf = 1e20;
if (x == INFINITY) return kApproInf;
@@ -37,10 +37,10 @@ __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ASSERT_MSG(out_index < out_c_stride,
"out_index < out_c_stride. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
out_index, out_c_stride, out_index, out_c_stride);
PADDLE_ENFORCE(out_index < out_c_stride,
"out_index < out_c_stride. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
out_index, out_c_stride, out_index, out_c_stride);
output_data[out_offset + out_index] = input_data[i];
}
}
@@ -62,10 +62,10 @@ __global__ void KernelUnpool2dMaxGrad(
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ASSERT_MSG(out_index < out_c_stride,
"out_index < out_c_stride. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
out_index, out_c_stride, out_index, out_c_stride);
PADDLE_ENFORCE(out_index < out_c_stride,
"out_index < out_c_stride. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
out_index, out_c_stride, out_index, out_c_stride);
input_grad[i] = output_grad[out_offset + out_index];
}
}
@@ -29,10 +29,10 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
struct CheckLabelValue {
HOSTDEVICE T operator()(const T& val) const {
PADDLE_ASSERT_MSG(val == static_cast<T>(0) || val == static_cast<T>(1),
"LabelValue of modified_huber_loss_op expected to be 0 "
"or 1, but got %ld. Please check input value.",
val);
PADDLE_ENFORCE(val == static_cast<T>(0) || val == static_cast<T>(1),
"LabelValue of modified_huber_loss_op expected to be 0 "
"or 1, but got %ld. Please check input value.",
val);
}
};
@@ -60,16 +60,16 @@ HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out,
size_t offset_i = offsets[i];
if (i == rank - 1) {
PADDLE_ASSERT_MSG(x_stride == 1,
"When i:%d == rank:%d - 1, x_stride of random_crop_op "
"expected to be 1, but got %ld. Please check input "
"value.",
i, rank, x_stride);
PADDLE_ASSERT_MSG(out_stride == 1,
"When i:%d == rank:%d - 1, out_stride of random_crop_op "
"expected to be 1, but got %ld. Please check input "
"value.",
i, rank, out_stride);
PADDLE_ENFORCE(x_stride == 1,
"When i:%d == rank:%d - 1, x_stride of random_crop_op "
"expected to be 1, but got %ld. Please check input "
"value.",
i, rank, x_stride);
PADDLE_ENFORCE(out_stride == 1,
"When i:%d == rank:%d - 1, out_stride of random_crop_op "
"expected to be 1, but got %ld. Please check input "
"value.",
i, rank, out_stride);
x += offset_i;
for (size_t j = 0; j < out_dim_i; ++j) {
*out++ = *x++;
@@ -34,8 +34,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
struct TolerableValue {
HOSTDEVICE T operator()(const T& x) const {
PADDLE_ASSERT_MSG(std::is_floating_point<T>::value,
"TolerableValue should be float in sample_logits_op.");
PADDLE_ENFORCE(std::is_floating_point<T>::value,
"TolerableValue should be float in sample_logits_op.");
const T kApproInf = 1e20;
if (x == INFINITY) return kApproInf;
if (x == -INFINITY) return -kApproInf;
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#if defined(__CUDA_ARCH__)
#include <stdio.h>
#define EXIT() asm("trap;")
#else
#include <assert.h>
#define EXIT() throw std::runtime_error("Exception encounter.")
#endif
// NOTE: PADDLE_ASSERT is mainly used in CUDA Kernel or HOSTDEVICE function.
#define PADDLE_ASSERT_MSG(_IS_NOT_ERROR, __FORMAT, ...) \
do { \
if (!(_IS_NOT_ERROR)) { \
printf("Exception: %s:%d Assertion `%s` failed. " __FORMAT "\n", \
__FILE__, __LINE__, TOSTRING(_IS_NOT_ERROR), ##__VA_ARGS__); \
EXIT(); \
} \
} while (0)
@@ -289,6 +289,19 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
} while (0)
#if defined(__CUDA_ARCH__)
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#define PADDLE_ENFORCE(_IS_NOT_ERROR, __FORMAT, ...) \
do { \
if (!(_IS_NOT_ERROR)) { \
printf("Exception: %s:%d Assertion `%s` failed. " __FORMAT "\n", \
__FILE__, __LINE__, #_IS_NOT_ERROR, ##__VA_ARGS__); \
asm("trap;"); \
} \
} while (0)
#else
#define PADDLE_ENFORCE(COND, ...) \
do { \
auto __cond__ = (COND); \
@@ -302,6 +315,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
} \
} \
} while (0)
#endif
#ifdef PADDLE_WITH_CUDA
#define PADDLE_ENFORCE_CUDA_SUCCESS(COND, ...) \
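For the host side, a minimal sketch (not part of the diff) of the pattern used in the TensorRT split plugin hunk above: CUDA runtime calls are wrapped in PADDLE_ENFORCE_CUDA_SUCCESS instead of comparing the return value against cudaSuccess inside PADDLE_ENFORCE. The helper function and variable names below are hypothetical.

#include <cuda_runtime.h>
#include "paddle/fluid/platform/enforce.h"

// Hypothetical helper showing the host-side check after this commit.
void CopyPointersToDevice(float** dst, float* const* src, size_t count,
                          cudaStream_t stream) {
  // Checks the cudaError_t return value and raises an enforce error on failure.
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(
      dst, src, count * sizeof(float*), cudaMemcpyHostToDevice, stream));
}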