diff --git a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc index e8b69871cb66055d18f7315935221a30f7332901..504a72fafe6e83d9b2f90c58db1c1c7c9f06e6ee 100644 --- a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc +++ b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc @@ -82,16 +82,23 @@ $$A[i] = T$$ class WriteToArrayInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { - PADDLE_ENFORCE(context->HasInput("I"), "Must set the subscript index"); - if (context->IsRuntime()) { - PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1, - "The number of element of subscript index must be 1"); - } + PADDLE_ENFORCE_EQ( + context->HasInput("I"), true, + platform::errors::InvalidArgument( + "Read/Write array operation must set the subscript index.")); + + // TODO(wangchaochaohu) control flow Op do not support runtime infer shape + // Later we add [ontext->GetInputDim("I")) == 1] check when it's supported + if (!context->HasInput("X")) { return; } - PADDLE_ENFORCE(context->HasOutput("Out"), NotHasOutError()); + PADDLE_ENFORCE_EQ( + context->HasOutput("Out"), true, + platform::errors::InvalidArgument( + "Read/Write array operation must set the output Tensor " + "to get the result.")); context->SetOutputDim("Out", context->GetInputDim("X")); // When compile time, we need to: @@ -106,13 +113,6 @@ class WriteToArrayInferShape : public framework::InferShapeBase { context->ShareLoD("X", /*->*/ "Out"); } } - - protected: - virtual const char *NotHasXError() const { return "Must set the lod tensor"; } - - virtual const char *NotHasOutError() const { - return "Must set the lod tensor array"; - } }; class WriteToArrayInferVarType : public framework::VarTypeInference { @@ -140,10 +140,15 @@ class ReadFromArrayOp : public ArrayOp { void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { auto *x = scope.FindVar(Input("X")); - PADDLE_ENFORCE(x != nullptr, "X must be set"); + PADDLE_ENFORCE_NOT_NULL( + x, + platform::errors::InvalidArgument( + "X(Input Variable) must be set when we call read array operation")); auto &x_array = x->Get(); auto *out = scope.FindVar(Output("Out")); - PADDLE_ENFORCE(out != nullptr, "Out must be set"); + PADDLE_ENFORCE_NOT_NULL(out, platform::errors::InvalidArgument( + "Out(Output Varibale) must be set when we " + "call read array operation")); size_t offset = GetOffset(scope, place); if (offset < x_array.size()) { auto *out_tensor = out->GetMutable(); @@ -199,15 +204,7 @@ $$T = A[i]$$ } }; -class ReadFromArrayInferShape : public WriteToArrayInferShape { - protected: - const char *NotHasXError() const override { - return "The input array X must be set"; - } - const char *NotHasOutError() const override { - return "The output tensor out must be set"; - } -}; +class ReadFromArrayInferShape : public WriteToArrayInferShape {}; template class WriteToArrayGradMaker : public framework::SingleGradOpMaker {