未验证 提交 f1ab2882 编写于 作者: W Wilber 提交者: GitHub

enhance inference error info. (#27251)

上级 3e20ddf7
......@@ -21,15 +21,21 @@
namespace paddle {
void ZeroCopyTensor::Reshape(const std::vector<int> &shape) {
PADDLE_ENFORCE(!name_.empty(),
"Need to SetName first, so that the corresponding tensor can "
"be retrieved.");
PADDLE_ENFORCE(input_or_output_,
"Can't reshape the output tensor, it is readonly");
PADDLE_ENFORCE(scope_);
PADDLE_ENFORCE_EQ(
name_.empty(), false,
platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can "
"be retrieved."));
PADDLE_ENFORCE_EQ(input_or_output_, true,
platform::errors::PermissionDenied(
"Can't reshape the output tensor, it is readonly"));
PADDLE_ENFORCE_NOT_NULL(scope_, platform::errors::PreconditionNotMet(
"The scope should not be nullptr."));
auto *scope = static_cast<framework::Scope *>(scope_);
auto *var = scope->FindVar(name_);
PADDLE_ENFORCE(var, "No tensor called [%s] in the runtime scope", name_);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", name_));
auto *tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(shape));
}
......@@ -45,8 +51,10 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
EAGER_GET_TENSOR;
PADDLE_ENFORCE_GT(
tensor->numel(), 0,
"You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)"
"function before retrieving mutable_data from input tensor.");
platform::errors::PreconditionNotMet(
"You should call ZeroCopyTensor::Reshape(const std::vector<int> "
"&shape)"
"function before retrieving mutable_data from input tensor."));
switch (static_cast<int>(place)) {
case static_cast<int>(PaddlePlace::kCPU): {
return tensor->mutable_data<T>(platform::CPUPlace());
......@@ -55,7 +63,8 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
return tensor->mutable_data<T>(platform::CUDAPlace(device_));
}
default:
PADDLE_THROW("Unsupported place: %d", static_cast<int>(place));
PADDLE_THROW(platform::errors::Unavailable("Unsupported place: %d",
static_cast<int>(place)));
break;
}
return nullptr;
......@@ -96,10 +105,11 @@ PaddleDType ZeroCopyTensor::type() const {
template <typename T>
void ZeroCopyTensor::copy_from_cpu(const T *data) {
EAGER_GET_TENSOR;
PADDLE_ENFORCE_GE(
tensor->numel(), 0,
"You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)"
"function before copying data from cpu.");
PADDLE_ENFORCE_GE(tensor->numel(), 0,
platform::errors::PreconditionNotMet(
"You should call ZeroCopyTensor::Reshape(const "
"std::vector<int> &shape)"
"function before copying data from cpu."));
size_t ele_size = tensor->numel() * sizeof(T);
if (place_ == PaddlePlace::kCPU) {
......@@ -116,7 +126,8 @@ void ZeroCopyTensor::copy_from_cpu(const T *data) {
memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
data, ele_size, dev_ctx->stream());
#else
PADDLE_THROW("Not compiled with CUDA, should not reach here.");
PADDLE_THROW(platform::errors::Unavailable(
"Not compiled with CUDA, should not reach here."));
#endif
}
}
......@@ -141,7 +152,8 @@ void ZeroCopyTensor::copy_to_cpu(T *data) {
cudaStreamSynchronize(dev_ctx->stream());
#else
PADDLE_THROW("Not compile with CUDA, should not reach here.");
PADDLE_THROW(platform::errors::Unavailable(
"Not compile with CUDA, should not reach here."));
#endif
}
}
......@@ -176,20 +188,27 @@ template PD_INFER_DECL uint8_t *ZeroCopyTensor::mutable_data<uint8_t>(
PaddlePlace place);
void *ZeroCopyTensor::FindTensor() const {
PADDLE_ENFORCE(!name_.empty(),
"Need to SetName first, so that the corresponding tensor can "
"be retrieved.");
PADDLE_ENFORCE(scope_);
PADDLE_ENFORCE_EQ(
name_.empty(), false,
platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can "
"be retrieved."));
PADDLE_ENFORCE_NOT_NULL(scope_, platform::errors::PreconditionNotMet(
"The scope should not be nullptr."));
auto *scope = static_cast<framework::Scope *>(scope_);
auto *var = scope->FindVar(name_);
PADDLE_ENFORCE(var, "No tensor called [%s] in the runtime scope", name_);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", name_));
auto *tensor = var->GetMutable<framework::LoDTensor>();
return tensor;
}
std::vector<int> ZeroCopyTensor::shape() const {
EAGER_GET_TENSOR;
PADDLE_ENFORCE(tensor_, "not found tensor called %s in the scope", name_);
PADDLE_ENFORCE_NOT_NULL(
tensor_, platform::errors::PreconditionNotMet(
"Not found tensor called %s in the scope", name_));
return framework::vectorize<int>(tensor->dims());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册