diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 4b520a393f2ed217feb18937684d5feeea0923b9..fec311e3ee3aa94bbd640a8d4a85840d96b3af43 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -476,6 +476,28 @@ const Tensor* ExecutionContext::LegacyInput( template <> const std::vector ExecutionContext::MultiInput( const std::string& name) const { + auto it = ctx_.inputs.find(name); + if (it == ctx_.inputs.end()) { + return {}; + } + const std::vector& vars = it->second; + std::vector res; + res.reserve(vars.size()); + std::transform(vars.begin(), vars.end(), std::back_inserter(res), + [&](Variable* var) -> const Tensor* { + if (var == nullptr) return nullptr; + PADDLE_ENFORCE( + var->IsType(), + "should be LoDTensor, but the received type is %s", + var->Type().name()); + return &(var->Get()); + }); + return res; +} + +template <> +const std::vector ExecutionContext::LegacyMultiInput( + const std::string& name) const { auto names = op().Inputs(name); std::vector res; res.reserve(names.size()); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 39190d07b4ccdd5ffd03e2d50bb0e577ac00af75..1fe2daacf1369902cde732422b4e65c3d156250f 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -197,8 +197,31 @@ class ExecutionContext { const std::vector MultiInputVar( const std::string& name) const { - auto names = op_.Inputs(name); + auto it = ctx_.inputs.find(name); + if (it == ctx_.inputs.end()) { + return {}; + } std::vector res; + res.reserve(it->second.size()); + std::transform(it->second.begin(), it->second.end(), + std::back_inserter(res), + [this](Variable* var) { return var; }); + return res; + } + + std::vector MultiOutputVar(const std::string& name) const { + auto names = op_.Outputs(name); + auto it = ctx_.outputs.find(name); + if (it == ctx_.outputs.end()) { + return {}; + } + return it->second; + } + + const std::vector LegacyMultiInputVar( + const std::string& name) const { + auto names = op_.Inputs(name); + std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), [this](const std::string& name) { @@ -208,7 +231,7 @@ class ExecutionContext { return res; } - std::vector MultiOutputVar(const std::string& name) const { + std::vector LegacyMultiOutputVar(const std::string& name) const { auto names = op_.Outputs(name); std::vector res; res.reserve(names.size()); @@ -250,6 +273,38 @@ class ExecutionContext { template const std::vector MultiInput(const std::string& name) const { + auto it = ctx_.inputs.find(name); + if (it == ctx_.inputs.end()) { + return {}; + } + const std::vector& vars = it->second; + std::vector res; + res.reserve(vars.size()); + std::transform(vars.begin(), vars.end(), std::back_inserter(res), + [&](Variable* var) -> const T* { + return var == nullptr ? nullptr : &var->Get(); + }); + return res; + } + + template + std::vector MultiOutput(const std::string& name) const { + auto it = ctx_.outputs.find(name); + if (it == ctx_.outputs.end()) { + return {}; + } + const std::vector& vars = it->second; + std::vector res; + res.reserve(vars.size()); + std::transform(vars.begin(), vars.end(), std::back_inserter(res), + [&](Variable* var) -> T* { + return var == nullptr ? nullptr : var->GetMutable(); + }); + return res; + } + + template + const std::vector LegacyMultiInput(const std::string& name) const { auto names = op_.Inputs(name); std::vector res; res.reserve(names.size()); @@ -262,7 +317,7 @@ class ExecutionContext { } template - std::vector MultiOutput(const std::string& name) const { + std::vector LegacyMultiOutput(const std::string& name) const { auto names = op_.Outputs(name); std::vector res; res.reserve(names.size()); @@ -321,6 +376,10 @@ template <> const std::vector ExecutionContext::MultiInput( const std::string& name) const; +template <> +const std::vector ExecutionContext::LegacyMultiInput( + const std::string& name) const; + template <> Tensor* ExecutionContext::Output(const std::string& name) const;