未验证 提交 dc225ed2 编写于 作者: L liym27 提交者: GitHub

OP (tensor_array_read_write) error message enhancement. test=develop (#23468)

上级 3cb5623d
...@@ -31,9 +31,14 @@ class ArrayOp : public framework::OperatorBase { ...@@ -31,9 +31,14 @@ class ArrayOp : public framework::OperatorBase {
size_t GetOffset(const framework::Scope &scope, size_t GetOffset(const framework::Scope &scope,
const platform::Place &place) const { const platform::Place &place) const {
auto *i = scope.FindVar(Input("I")); auto *i = scope.FindVar(Input("I"));
PADDLE_ENFORCE(i != nullptr, "I must be set"); PADDLE_ENFORCE_NOT_NULL(
i, platform::errors::NotFound("Input(I) is not found."));
auto &i_tensor = i->Get<framework::LoDTensor>(); auto &i_tensor = i->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1); PADDLE_ENFORCE_EQ(i_tensor.numel(), 1,
platform::errors::InvalidArgument(
"Input(I) must have numel 1. "
"But received %d, and it's shape is [%s].",
i_tensor.numel(), i_tensor.dims()));
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......
...@@ -83,8 +83,7 @@ class WriteToArrayInferShape : public framework::InferShapeBase { ...@@ -83,8 +83,7 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
context->HasInput("I"), true, context->HasInput("I"), true,
platform::errors::InvalidArgument( platform::errors::NotFound("Input(I) of WriteToArrayOp is not found."));
"Read/Write array operation must set the subscript index."));
// TODO(wangchaochaohu) control flow Op do not support runtime infer shape // TODO(wangchaochaohu) control flow Op do not support runtime infer shape
// Later we add [ontext->GetInputDim("I")) == 1] check when it's supported // Later we add [ontext->GetInputDim("I")) == 1] check when it's supported
...@@ -93,11 +92,9 @@ class WriteToArrayInferShape : public framework::InferShapeBase { ...@@ -93,11 +92,9 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
return; return;
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true,
context->HasOutput("Out"), true, platform::errors::NotFound(
platform::errors::InvalidArgument( "Output(Out) of WriteToArrayOp is not found."));
"Read/Write array operation must set the output Tensor "
"to get the result."));
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
// When compile time, we need to: // When compile time, we need to:
...@@ -139,15 +136,14 @@ class ReadFromArrayOp : public ArrayOp { ...@@ -139,15 +136,14 @@ class ReadFromArrayOp : public ArrayOp {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X")); auto *x = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(x,
x, platform::errors::NotFound(
platform::errors::InvalidArgument( "Input(X) of ReadFromArrayOp is not found."));
"X(Input Variable) must be set when we call read array operation"));
auto &x_array = x->Get<framework::LoDTensorArray>(); auto &x_array = x->Get<framework::LoDTensorArray>();
auto *out = scope.FindVar(Output("Out")); auto *out = scope.FindVar(Output("Out"));
PADDLE_ENFORCE_NOT_NULL(out, platform::errors::InvalidArgument( PADDLE_ENFORCE_NOT_NULL(
"Out(Output Varibale) must be set when we " out, platform::errors::NotFound(
"call read array operation")); "Output(Out) of ReadFromArrayOp is not found."));
size_t offset = GetOffset(scope, place); size_t offset = GetOffset(scope, place);
if (offset < x_array.size()) { if (offset < x_array.size()) {
auto *out_tensor = out->GetMutable<framework::LoDTensor>(); auto *out_tensor = out->GetMutable<framework::LoDTensor>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册