提交 969ad966 编写于 作者: X Xin Pan

all converted

test=develop
上级 a872eb90
...@@ -476,6 +476,28 @@ const Tensor* ExecutionContext::LegacyInput<Tensor>( ...@@ -476,6 +476,28 @@ const Tensor* ExecutionContext::LegacyInput<Tensor>(
template <> template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const { const std::string& name) const {
auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) {
return {};
}
const std::vector<Variable*>& vars = it->second;
std::vector<const Tensor*> res;
res.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(res),
[&](Variable* var) -> const Tensor* {
if (var == nullptr) return nullptr;
PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"should be LoDTensor, but the received type is %s",
var->Type().name());
return &(var->Get<LoDTensor>());
});
return res;
}
template <>
const std::vector<const Tensor*> ExecutionContext::LegacyMultiInput<Tensor>(
const std::string& name) const {
auto names = op().Inputs(name); auto names = op().Inputs(name);
std::vector<const Tensor*> res; std::vector<const Tensor*> res;
res.reserve(names.size()); res.reserve(names.size());
......
...@@ -197,8 +197,31 @@ class ExecutionContext { ...@@ -197,8 +197,31 @@ class ExecutionContext {
const std::vector<const Variable*> MultiInputVar( const std::vector<const Variable*> MultiInputVar(
const std::string& name) const { const std::string& name) const {
auto names = op_.Inputs(name); auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) {
return {};
}
std::vector<const Variable*> res; std::vector<const Variable*> res;
res.reserve(it->second.size());
std::transform(it->second.begin(), it->second.end(),
std::back_inserter(res),
[this](Variable* var) { return var; });
return res;
}
std::vector<Variable*> MultiOutputVar(const std::string& name) const {
auto names = op_.Outputs(name);
auto it = ctx_.outputs.find(name);
if (it == ctx_.outputs.end()) {
return {};
}
return it->second;
}
const std::vector<Variable*> LegacyMultiInputVar(
const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<Variable*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) { [this](const std::string& name) {
...@@ -208,7 +231,7 @@ class ExecutionContext { ...@@ -208,7 +231,7 @@ class ExecutionContext {
return res; return res;
} }
std::vector<Variable*> MultiOutputVar(const std::string& name) const { std::vector<Variable*> LegacyMultiOutputVar(const std::string& name) const {
auto names = op_.Outputs(name); auto names = op_.Outputs(name);
std::vector<Variable*> res; std::vector<Variable*> res;
res.reserve(names.size()); res.reserve(names.size());
...@@ -250,6 +273,38 @@ class ExecutionContext { ...@@ -250,6 +273,38 @@ class ExecutionContext {
template <typename T> template <typename T>
const std::vector<const T*> MultiInput(const std::string& name) const { const std::vector<const T*> MultiInput(const std::string& name) const {
auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) {
return {};
}
const std::vector<Variable*>& vars = it->second;
std::vector<const T*> res;
res.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(res),
[&](Variable* var) -> const T* {
return var == nullptr ? nullptr : &var->Get<T>();
});
return res;
}
template <typename T>
std::vector<T*> MultiOutput(const std::string& name) const {
auto it = ctx_.outputs.find(name);
if (it == ctx_.outputs.end()) {
return {};
}
const std::vector<Variable*>& vars = it->second;
std::vector<T*> res;
res.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(res),
[&](Variable* var) -> T* {
return var == nullptr ? nullptr : var->GetMutable<T>();
});
return res;
}
template <typename T>
const std::vector<const T*> LegacyMultiInput(const std::string& name) const {
auto names = op_.Inputs(name); auto names = op_.Inputs(name);
std::vector<const T*> res; std::vector<const T*> res;
res.reserve(names.size()); res.reserve(names.size());
...@@ -262,7 +317,7 @@ class ExecutionContext { ...@@ -262,7 +317,7 @@ class ExecutionContext {
} }
template <typename T> template <typename T>
std::vector<T*> MultiOutput(const std::string& name) const { std::vector<T*> LegacyMultiOutput(const std::string& name) const {
auto names = op_.Outputs(name); auto names = op_.Outputs(name);
std::vector<T*> res; std::vector<T*> res;
res.reserve(names.size()); res.reserve(names.size());
...@@ -321,6 +376,10 @@ template <> ...@@ -321,6 +376,10 @@ template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const; const std::string& name) const;
template <>
const std::vector<const Tensor*> ExecutionContext::LegacyMultiInput<Tensor>(
const std::string& name) const;
template <> template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const; Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册