From 5639f49b16bcc03c758c7a6c1574c7371ef26dd6 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Fri, 25 Jan 2019 13:12:36 +0000 Subject: [PATCH] test=develop, fix/multi_output_support_imperative --- paddle/fluid/framework/operator.cc | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index ee9f6a48054..031e7191396 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -555,18 +555,17 @@ Tensor* ExecutionContext::LegacyOutput(const std::string& name) const { template <> std::vector ExecutionContext::MultiOutput( const std::string& name) const { - auto names = op().Outputs(name); + auto it = ctx_.outputs.find(name); + if (it == ctx_.outputs.end()) { + return {}; + } + const std::vector& vars = it->second; std::vector res; - res.reserve(names.size()); - std::transform(names.begin(), names.end(), std::back_inserter(res), - [&](const std::string& sub_name) -> Tensor* { - auto var = scope_.FindVar(sub_name); - if (var == nullptr) return nullptr; - PADDLE_ENFORCE( - var->IsType(), - "%s should be LoDTensor, but the received type is %s", - sub_name, ToTypeName(var->Type())); - return var->GetMutable(); + res.reserve(vars.size()); + std::transform(vars.begin(), vars.end(), std::back_inserter(res), + [&](Variable* var) -> Tensor* { + return var == nullptr ? nullptr + : var->GetMutable(); }); return res; } -- GitLab