From 975a51294e20c122e7143a232261d4fd49ac5643 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 9 Oct 2017 23:55:35 -0700 Subject: [PATCH] infer feed operator output variable shape with dims attribute --- paddle/operators/feed_op.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index 29e128ce7e5..1d65c2bb469 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -32,8 +32,12 @@ class FeedOp : public framework::OperatorWithKernel { g_feed_variable->Get>(); PADDLE_ENFORCE_GT(tensors.size(), static_cast(col)); - auto in_dim = tensors[col].dims(); - ctx->SetOutputDim("Out", in_dim); + + auto& shape = ctx->Attrs().Get>("dims"); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + ctx->SetOutputDim("Out", framework::make_ddim(shape_int64)); // TODO(qijun): need to handle LodTensor later } -- GitLab