未验证 提交 f1ab2882 编写于 作者: W Wilber 提交者: GitHub

enhance inference error info. (#27251)

上级 3e20ddf7
...@@ -21,15 +21,21 @@ ...@@ -21,15 +21,21 @@
namespace paddle { namespace paddle {
void ZeroCopyTensor::Reshape(const std::vector<int> &shape) { void ZeroCopyTensor::Reshape(const std::vector<int> &shape) {
PADDLE_ENFORCE(!name_.empty(), PADDLE_ENFORCE_EQ(
name_.empty(), false,
platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can " "Need to SetName first, so that the corresponding tensor can "
"be retrieved."); "be retrieved."));
PADDLE_ENFORCE(input_or_output_, PADDLE_ENFORCE_EQ(input_or_output_, true,
"Can't reshape the output tensor, it is readonly"); platform::errors::PermissionDenied(
PADDLE_ENFORCE(scope_); "Can't reshape the output tensor, it is readonly"));
PADDLE_ENFORCE_NOT_NULL(scope_, platform::errors::PreconditionNotMet(
"The scope should not be nullptr."));
auto *scope = static_cast<framework::Scope *>(scope_); auto *scope = static_cast<framework::Scope *>(scope_);
auto *var = scope->FindVar(name_); auto *var = scope->FindVar(name_);
PADDLE_ENFORCE(var, "No tensor called [%s] in the runtime scope", name_); PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", name_));
auto *tensor = var->GetMutable<framework::LoDTensor>(); auto *tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(shape)); tensor->Resize(framework::make_ddim(shape));
} }
...@@ -45,8 +51,10 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) { ...@@ -45,8 +51,10 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
EAGER_GET_TENSOR; EAGER_GET_TENSOR;
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
tensor->numel(), 0, tensor->numel(), 0,
"You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)" platform::errors::PreconditionNotMet(
"function before retrieving mutable_data from input tensor."); "You should call ZeroCopyTensor::Reshape(const std::vector<int> "
"&shape)"
"function before retrieving mutable_data from input tensor."));
switch (static_cast<int>(place)) { switch (static_cast<int>(place)) {
case static_cast<int>(PaddlePlace::kCPU): { case static_cast<int>(PaddlePlace::kCPU): {
return tensor->mutable_data<T>(platform::CPUPlace()); return tensor->mutable_data<T>(platform::CPUPlace());
...@@ -55,7 +63,8 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) { ...@@ -55,7 +63,8 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
return tensor->mutable_data<T>(platform::CUDAPlace(device_)); return tensor->mutable_data<T>(platform::CUDAPlace(device_));
} }
default: default:
PADDLE_THROW("Unsupported place: %d", static_cast<int>(place)); PADDLE_THROW(platform::errors::Unavailable("Unsupported place: %d",
static_cast<int>(place)));
break; break;
} }
return nullptr; return nullptr;
...@@ -96,10 +105,11 @@ PaddleDType ZeroCopyTensor::type() const { ...@@ -96,10 +105,11 @@ PaddleDType ZeroCopyTensor::type() const {
template <typename T> template <typename T>
void ZeroCopyTensor::copy_from_cpu(const T *data) { void ZeroCopyTensor::copy_from_cpu(const T *data) {
EAGER_GET_TENSOR; EAGER_GET_TENSOR;
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(tensor->numel(), 0,
tensor->numel(), 0, platform::errors::PreconditionNotMet(
"You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)" "You should call ZeroCopyTensor::Reshape(const "
"function before copying data from cpu."); "std::vector<int> &shape)"
"function before copying data from cpu."));
size_t ele_size = tensor->numel() * sizeof(T); size_t ele_size = tensor->numel() * sizeof(T);
if (place_ == PaddlePlace::kCPU) { if (place_ == PaddlePlace::kCPU) {
...@@ -116,7 +126,8 @@ void ZeroCopyTensor::copy_from_cpu(const T *data) { ...@@ -116,7 +126,8 @@ void ZeroCopyTensor::copy_from_cpu(const T *data) {
memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(), memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
data, ele_size, dev_ctx->stream()); data, ele_size, dev_ctx->stream());
#else #else
PADDLE_THROW("Not compiled with CUDA, should not reach here."); PADDLE_THROW(platform::errors::Unavailable(
"Not compiled with CUDA, should not reach here."));
#endif #endif
} }
} }
...@@ -141,7 +152,8 @@ void ZeroCopyTensor::copy_to_cpu(T *data) { ...@@ -141,7 +152,8 @@ void ZeroCopyTensor::copy_to_cpu(T *data) {
cudaStreamSynchronize(dev_ctx->stream()); cudaStreamSynchronize(dev_ctx->stream());
#else #else
PADDLE_THROW("Not compile with CUDA, should not reach here."); PADDLE_THROW(platform::errors::Unavailable(
"Not compile with CUDA, should not reach here."));
#endif #endif
} }
} }
...@@ -176,20 +188,27 @@ template PD_INFER_DECL uint8_t *ZeroCopyTensor::mutable_data<uint8_t>( ...@@ -176,20 +188,27 @@ template PD_INFER_DECL uint8_t *ZeroCopyTensor::mutable_data<uint8_t>(
PaddlePlace place); PaddlePlace place);
void *ZeroCopyTensor::FindTensor() const { void *ZeroCopyTensor::FindTensor() const {
PADDLE_ENFORCE(!name_.empty(), PADDLE_ENFORCE_EQ(
name_.empty(), false,
platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can " "Need to SetName first, so that the corresponding tensor can "
"be retrieved."); "be retrieved."));
PADDLE_ENFORCE(scope_); PADDLE_ENFORCE_NOT_NULL(scope_, platform::errors::PreconditionNotMet(
"The scope should not be nullptr."));
auto *scope = static_cast<framework::Scope *>(scope_); auto *scope = static_cast<framework::Scope *>(scope_);
auto *var = scope->FindVar(name_); auto *var = scope->FindVar(name_);
PADDLE_ENFORCE(var, "No tensor called [%s] in the runtime scope", name_); PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", name_));
auto *tensor = var->GetMutable<framework::LoDTensor>(); auto *tensor = var->GetMutable<framework::LoDTensor>();
return tensor; return tensor;
} }
std::vector<int> ZeroCopyTensor::shape() const { std::vector<int> ZeroCopyTensor::shape() const {
EAGER_GET_TENSOR; EAGER_GET_TENSOR;
PADDLE_ENFORCE(tensor_, "not found tensor called %s in the scope", name_); PADDLE_ENFORCE_NOT_NULL(
tensor_, platform::errors::PreconditionNotMet(
"Not found tensor called %s in the scope", name_));
return framework::vectorize<int>(tensor->dims()); return framework::vectorize<int>(tensor->dims());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册