未验证 提交 55d0c8fd 编写于 作者: Y Yiqun Liu 提交者: GitHub

Enhance the error message of feed_op. (#23526)

上级 71b5f1d2
......@@ -28,35 +28,52 @@ class FeedOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
// get device context from pool
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
OP_INOUT_CHECK(HasInputs("X"), "Input", "X", "Feed");
OP_INOUT_CHECK(HasOutputs("Out"), "Output", "Out", "Feed");
auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name);
PADDLE_ENFORCE(feed_var != nullptr,
"Cannot find feed_var in scope, feed_var_name is %s",
feed_var_name);
PADDLE_ENFORCE_NOT_NULL(
feed_var,
platform::errors::NotFound(
"Input varibale(%s) cannot be found in scope for operator 'Feed'.",
feed_var_name));
auto out_name = this->Output("Out");
auto *out_var = scope.FindVar(out_name);
PADDLE_ENFORCE(out_var != nullptr,
"Cannot find out_var in scope, out_var_name is %s",
out_name);
PADDLE_ENFORCE_NOT_NULL(
out_var,
platform::errors::NotFound(
"Output variable(%s) cannot be found in scope for operator 'Feed'",
out_name));
auto col = Attr<int>("col");
PADDLE_ENFORCE_GE(col, 0,
platform::errors::InvalidArgument(
"Expected the column index (the attribute 'col' of "
"operator 'Feed') of current feeding variable to be "
"no less than 0. But received column index = %d.",
col));
VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var "
<< out_name;
VLOG(3) << "Feed variable " << feed_var_name << "'s " << col
<< " column to variable " << out_name;
auto &feed_list = feed_var->Get<framework::FeedFetchList>();
PADDLE_ENFORCE_LT(static_cast<size_t>(col), feed_list.size());
PADDLE_ENFORCE_LT(
static_cast<size_t>(col), feed_list.size(),
platform::errors::InvalidArgument(
"The column index of current feeding variable is expected to be "
"less than the length of feeding list. But received column index = "
"%d, the length of feeding list = %d",
col, feed_list.size()));
auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
if (platform::is_same_place(feed_item.place(), place)) {
out_item->ShareDataWith(feed_item);
} else {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
framework::TensorCopy(feed_item, place, *dev_ctx, out_item);
}
out_item->set_lod(feed_item.lod());
......@@ -66,9 +83,13 @@ class FeedOp : public framework::OperatorBase {
class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of feed op");
AddOutput("Out", "The output of feed op");
AddAttr<int>("col", "(int) The column of feed");
AddInput("X",
"(vector<LoDTensor>) A feeding list of LoDTensor, which may have "
"different dimension and data type.");
AddOutput("Out",
"(LoDTensor) The LoDTensor which is a copy of the col-th feeding "
"object.");
AddAttr<int>("col", "(int) The column index of current feeding object.");
AddComment(R"DOC(
Feed Operator.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册