diff --git a/paddle/fluid/operators/controlflow/feed_op.cc b/paddle/fluid/operators/controlflow/feed_op.cc index 1e3af06cea2917e5a1fc184318b96b2349966a4a..b40d7b0ce8f12172dfeee5c13f02225b116e78de 100644 --- a/paddle/fluid/operators/controlflow/feed_op.cc +++ b/paddle/fluid/operators/controlflow/feed_op.cc @@ -28,35 +28,52 @@ class FeedOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - // get device context from pool - auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); + OP_INOUT_CHECK(HasInputs("X"), "Input", "X", "Feed"); + OP_INOUT_CHECK(HasOutputs("Out"), "Output", "Out", "Feed"); auto feed_var_name = Input("X"); auto *feed_var = scope.FindVar(feed_var_name); - - PADDLE_ENFORCE(feed_var != nullptr, - "Cannot find feed_var in scope, feed_var_name is %s", - feed_var_name); + PADDLE_ENFORCE_NOT_NULL( + feed_var, + platform::errors::NotFound( + "Input varibale(%s) cannot be found in scope for operator 'Feed'.", + feed_var_name)); auto out_name = this->Output("Out"); auto *out_var = scope.FindVar(out_name); - PADDLE_ENFORCE(out_var != nullptr, - "Cannot find out_var in scope, out_var_name is %s", - out_name); + PADDLE_ENFORCE_NOT_NULL( + out_var, + platform::errors::NotFound( + "Output variable(%s) cannot be found in scope for operator 'Feed'", + out_name)); auto col = Attr("col"); + PADDLE_ENFORCE_GE(col, 0, + platform::errors::InvalidArgument( + "Expected the column index (the attribute 'col' of " + "operator 'Feed') of current feeding variable to be " + "no less than 0. But received column index = %d.", + col)); - VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var " - << out_name; + VLOG(3) << "Feed variable " << feed_var_name << "'s " << col + << " column to variable " << out_name; auto &feed_list = feed_var->Get(); - PADDLE_ENFORCE_LT(static_cast(col), feed_list.size()); + PADDLE_ENFORCE_LT( + static_cast(col), feed_list.size(), + platform::errors::InvalidArgument( + "The column index of current feeding variable is expected to be " + "less than the length of feeding list. But received column index = " + "%d, the length of feeding list = %d", + col, feed_list.size())); + auto &feed_item = feed_list.at(static_cast(col)); auto *out_item = out_var->GetMutable(); if (platform::is_same_place(feed_item.place(), place)) { out_item->ShareDataWith(feed_item); } else { + auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); framework::TensorCopy(feed_item, place, *dev_ctx, out_item); } out_item->set_lod(feed_item.lod()); @@ -66,9 +83,13 @@ class FeedOp : public framework::OperatorBase { class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "The input of feed op"); - AddOutput("Out", "The output of feed op"); - AddAttr("col", "(int) The column of feed"); + AddInput("X", + "(vector) A feeding list of LoDTensor, which may have " + "different dimension and data type."); + AddOutput("Out", + "(LoDTensor) The LoDTensor which is a copy of the col-th feeding " + "object."); + AddAttr("col", "(int) The column index of current feeding object."); AddComment(R"DOC( Feed Operator.