提交 d1000623 编写于 作者: Y Yi Wang

Update usage of Scope

上级 5031c93a
...@@ -119,19 +119,19 @@ class KernelContext { ...@@ -119,19 +119,19 @@ class KernelContext {
: op_(*op), scope_(scope), device_context_(device_context) {} : op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const { const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]); return scope_->FindVar(op_.inputs_[index]);
} }
Variable* Output(int index) const { Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]); return scope_->FindVar(op_.outputs_[index]);
} }
const Variable* Input(const std::string& name) const { const Variable* Input(const std::string& name) const {
return scope_->GetVariable(op_.Input(name)); return scope_->FindVar(op_.Input(name));
} }
const Variable* Output(const std::string& name) const { const Variable* Output(const std::string& name) const {
return scope_->GetVariable(op_.Output(name)); return scope_->FindVar(op_.Output(name));
} }
const std::vector<const Variable*> Inputs(const std::string& name) const { const std::vector<const Variable*> Inputs(const std::string& name) const {
...@@ -139,7 +139,7 @@ class KernelContext { ...@@ -139,7 +139,7 @@ class KernelContext {
std::vector<const Variable*> res; std::vector<const Variable*> res;
std::transform( std::transform(
names.begin(), names.end(), res.begin(), names.begin(), names.end(), res.begin(),
[this](const std::string& name) { return scope_->GetVariable(name); }); [this](const std::string& name) { return scope_->FindVar(name); });
return res; return res;
} }
...@@ -148,7 +148,7 @@ class KernelContext { ...@@ -148,7 +148,7 @@ class KernelContext {
std::vector<const Variable*> res; std::vector<const Variable*> res;
std::transform( std::transform(
names.begin(), names.end(), res.begin(), names.begin(), names.end(), res.begin(),
[this](const std::string& name) { return scope_->GetVariable(name); }); [this](const std::string& name) { return scope_->FindVar(name); });
return res; return res;
} }
...@@ -244,7 +244,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -244,7 +244,7 @@ class OperatorWithKernel : public OperatorBase {
container->reserve(var_names.size()); container->reserve(var_names.size());
VarToTensor<T> convert; VarToTensor<T> convert;
for (auto& name : var_names) { for (auto& name : var_names) {
auto var = scope->GetVariable(name); auto var = scope->FindVar(name);
if (var != nullptr) { if (var != nullptr) {
container->push_back(convert(var)); container->push_back(convert(var));
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册