diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc index 29e128ce7e5dc7b13c9938505c0ee506c7c48155..1d65c2bb4693c8c4a131c073637bc7f5b860ab64 100644 --- a/paddle/operators/feed_op.cc +++ b/paddle/operators/feed_op.cc @@ -32,8 +32,12 @@ class FeedOp : public framework::OperatorWithKernel { g_feed_variable->Get>(); PADDLE_ENFORCE_GT(tensors.size(), static_cast(col)); - auto in_dim = tensors[col].dims(); - ctx->SetOutputDim("Out", in_dim); + + auto& shape = ctx->Attrs().Get>("dims"); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + ctx->SetOutputDim("Out", framework::make_ddim(shape_int64)); // TODO(qijun): need to handle LodTensor later }