diff --git a/paddle/fluid/operators/controlflow/fetch_op.cc b/paddle/fluid/operators/controlflow/fetch_op.cc index cd09ceb60f2ae23be5889faf1d93018872ceed7f..29be74a501f93b0e90a2a0a5c8f56636d985bfee 100644 --- a/paddle/fluid/operators/controlflow/fetch_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_op.cc @@ -31,24 +31,39 @@ class FetchOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { + OP_INOUT_CHECK(HasInputs("X"), "Input", "X", "Fetch"); + OP_INOUT_CHECK(HasOutputs("Out"), "Output", "Out", "Fetch"); + auto fetch_var_name = Input("X"); auto *fetch_var = scope.FindVar(fetch_var_name); - PADDLE_ENFORCE(fetch_var != nullptr, - "Cannot find fetch variable in scope, fetch_var_name is %s", - fetch_var_name); + PADDLE_ENFORCE_NOT_NULL( + fetch_var, + platform::errors::NotFound( + "Input variable(%s) cannot be found in scope for operator 'Fetch'.", + fetch_var_name)); - auto out_name = this->Output("Out"); + auto out_name = Output("Out"); auto *out_var = scope.FindVar(out_name); - PADDLE_ENFORCE(out_var != nullptr, - "Cannot find out_var in scope, out_var_name is %s", - out_name); - - auto col = static_cast(Attr("col")); + PADDLE_ENFORCE_NOT_NULL(out_var, platform::errors::NotFound( + "Output variable(%s) cannot be found " + "in scope for operator 'Fetch'.", + out_name)); + + int col = Attr("col"); + PADDLE_ENFORCE_GE( + col, 0, platform::errors::InvalidArgument( + "Expected the column index (the attribute 'col' of " + "operator 'Fetch') of current fetching variable to be " + "no less than 0. But received column index = %d.", + col)); + + VLOG(3) << "Fetch variable " << fetch_var_name << " to variable " + << out_name << "'s " << col << " column."; auto *fetch_list = out_var->GetMutable(); auto &src_item = fetch_var->Get(); - if (col >= fetch_list->size()) { + if (static_cast(col) >= fetch_list->size()) { fetch_list->resize(col + 1); } auto &dst_item = fetch_list->at(col); @@ -81,17 +96,19 @@ class FetchOp : public framework::OperatorBase { dst_item.Resize({0}); } dst_item.set_lod(src_item.lod()); - - VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name; } }; class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "The input of fetch op"); - AddOutput("Out", "The output of fetch op"); - AddAttr("col", "(int) The column of fetch"); + AddInput("X", + "(LoDTensor) The resulted LoDTensor which is expected to return " + "to users."); + AddOutput("Out", + "(vector) A fetching list of LoDTensor which may have " + "different dimension, shape and data type."); + AddAttr("col", "(int) The column index of fetching object."); AddComment(R"DOC( Fetch Operator.