未验证 提交 55d0c8fd 编写于 作者: Y Yiqun Liu 提交者: GitHub

Enhance the error message of feed_op. (#23526)

上级 71b5f1d2
...@@ -28,35 +28,52 @@ class FeedOp : public framework::OperatorBase { ...@@ -28,35 +28,52 @@ class FeedOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
// get device context from pool OP_INOUT_CHECK(HasInputs("X"), "Input", "X", "Feed");
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); OP_INOUT_CHECK(HasOutputs("Out"), "Output", "Out", "Feed");
auto feed_var_name = Input("X"); auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name); auto *feed_var = scope.FindVar(feed_var_name);
PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(feed_var != nullptr, feed_var,
"Cannot find feed_var in scope, feed_var_name is %s", platform::errors::NotFound(
feed_var_name); "Input varibale(%s) cannot be found in scope for operator 'Feed'.",
feed_var_name));
auto out_name = this->Output("Out"); auto out_name = this->Output("Out");
auto *out_var = scope.FindVar(out_name); auto *out_var = scope.FindVar(out_name);
PADDLE_ENFORCE(out_var != nullptr, PADDLE_ENFORCE_NOT_NULL(
"Cannot find out_var in scope, out_var_name is %s", out_var,
out_name); platform::errors::NotFound(
"Output variable(%s) cannot be found in scope for operator 'Feed'",
out_name));
auto col = Attr<int>("col"); auto col = Attr<int>("col");
PADDLE_ENFORCE_GE(col, 0,
platform::errors::InvalidArgument(
"Expected the column index (the attribute 'col' of "
"operator 'Feed') of current feeding variable to be "
"no less than 0. But received column index = %d.",
col));
VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var " VLOG(3) << "Feed variable " << feed_var_name << "'s " << col
<< out_name; << " column to variable " << out_name;
auto &feed_list = feed_var->Get<framework::FeedFetchList>(); auto &feed_list = feed_var->Get<framework::FeedFetchList>();
PADDLE_ENFORCE_LT(static_cast<size_t>(col), feed_list.size()); PADDLE_ENFORCE_LT(
static_cast<size_t>(col), feed_list.size(),
platform::errors::InvalidArgument(
"The column index of current feeding variable is expected to be "
"less than the length of feeding list. But received column index = "
"%d, the length of feeding list = %d",
col, feed_list.size()));
auto &feed_item = feed_list.at(static_cast<size_t>(col)); auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>(); auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
if (platform::is_same_place(feed_item.place(), place)) { if (platform::is_same_place(feed_item.place(), place)) {
out_item->ShareDataWith(feed_item); out_item->ShareDataWith(feed_item);
} else { } else {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
framework::TensorCopy(feed_item, place, *dev_ctx, out_item); framework::TensorCopy(feed_item, place, *dev_ctx, out_item);
} }
out_item->set_lod(feed_item.lod()); out_item->set_lod(feed_item.lod());
...@@ -66,9 +83,13 @@ class FeedOp : public framework::OperatorBase { ...@@ -66,9 +83,13 @@ class FeedOp : public framework::OperatorBase {
class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker { class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "The input of feed op"); AddInput("X",
AddOutput("Out", "The output of feed op"); "(vector<LoDTensor>) A feeding list of LoDTensor, which may have "
AddAttr<int>("col", "(int) The column of feed"); "different dimension and data type.");
AddOutput("Out",
"(LoDTensor) The LoDTensor which is a copy of the col-th feeding "
"object.");
AddAttr<int>("col", "(int) The column index of current feeding object.");
AddComment(R"DOC( AddComment(R"DOC(
Feed Operator. Feed Operator.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册