提交 975a5129 编写于 作者: Q qijun

infer feed operator output variable shape with dims attribute

上级 2fc7fc7a
...@@ -32,8 +32,12 @@ class FeedOp : public framework::OperatorWithKernel { ...@@ -32,8 +32,12 @@ class FeedOp : public framework::OperatorWithKernel {
g_feed_variable->Get<std::vector<framework::Tensor>>(); g_feed_variable->Get<std::vector<framework::Tensor>>();
PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col)); PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
auto in_dim = tensors[col].dims();
ctx->SetOutputDim("Out", in_dim); auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
ctx->SetOutputDim("Out", framework::make_ddim(shape_int64));
// TODO(qijun): need to handle LodTensor later // TODO(qijun): need to handle LodTensor later
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册