未验证 提交 69e3f993 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the error message (#23212)

* refine the error message of tensor_array_read_write Op
上级 5c59d213
...@@ -82,16 +82,23 @@ $$A[i] = T$$ ...@@ -82,16 +82,23 @@ $$A[i] = T$$
class WriteToArrayInferShape : public framework::InferShapeBase { class WriteToArrayInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("I"), "Must set the subscript index"); PADDLE_ENFORCE_EQ(
if (context->IsRuntime()) { context->HasInput("I"), true,
PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1, platform::errors::InvalidArgument(
"The number of element of subscript index must be 1"); "Read/Write array operation must set the subscript index."));
}
// TODO(wangchaochaohu) control flow Op do not support runtime infer shape
// Later we add [ontext->GetInputDim("I")) == 1] check when it's supported
if (!context->HasInput("X")) { if (!context->HasInput("X")) {
return; return;
} }
PADDLE_ENFORCE(context->HasOutput("Out"), NotHasOutError()); PADDLE_ENFORCE_EQ(
context->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Read/Write array operation must set the output Tensor "
"to get the result."));
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
// When compile time, we need to: // When compile time, we need to:
...@@ -106,13 +113,6 @@ class WriteToArrayInferShape : public framework::InferShapeBase { ...@@ -106,13 +113,6 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
context->ShareLoD("X", /*->*/ "Out"); context->ShareLoD("X", /*->*/ "Out");
} }
} }
protected:
virtual const char *NotHasXError() const { return "Must set the lod tensor"; }
virtual const char *NotHasOutError() const {
return "Must set the lod tensor array";
}
}; };
class WriteToArrayInferVarType : public framework::VarTypeInference { class WriteToArrayInferVarType : public framework::VarTypeInference {
...@@ -140,10 +140,15 @@ class ReadFromArrayOp : public ArrayOp { ...@@ -140,10 +140,15 @@ class ReadFromArrayOp : public ArrayOp {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto *x = scope.FindVar(Input("X")); auto *x = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x != nullptr, "X must be set"); PADDLE_ENFORCE_NOT_NULL(
x,
platform::errors::InvalidArgument(
"X(Input Variable) must be set when we call read array operation"));
auto &x_array = x->Get<framework::LoDTensorArray>(); auto &x_array = x->Get<framework::LoDTensorArray>();
auto *out = scope.FindVar(Output("Out")); auto *out = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out != nullptr, "Out must be set"); PADDLE_ENFORCE_NOT_NULL(out, platform::errors::InvalidArgument(
"Out(Output Varibale) must be set when we "
"call read array operation"));
size_t offset = GetOffset(scope, place); size_t offset = GetOffset(scope, place);
if (offset < x_array.size()) { if (offset < x_array.size()) {
auto *out_tensor = out->GetMutable<framework::LoDTensor>(); auto *out_tensor = out->GetMutable<framework::LoDTensor>();
...@@ -199,15 +204,7 @@ $$T = A[i]$$ ...@@ -199,15 +204,7 @@ $$T = A[i]$$
} }
}; };
class ReadFromArrayInferShape : public WriteToArrayInferShape { class ReadFromArrayInferShape : public WriteToArrayInferShape {};
protected:
const char *NotHasXError() const override {
return "The input array X must be set";
}
const char *NotHasOutError() const override {
return "The output tensor out must be set";
}
};
template <typename T> template <typename T>
class WriteToArrayGradMaker : public framework::SingleGradOpMaker<T> { class WriteToArrayGradMaker : public framework::SingleGradOpMaker<T> {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册