提交 c5afddc6 编写于 作者: Y Yi Wang

Rewrite Scope

上级 0973c2c9
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/scope.h"
namespace paddle {
namespace framework {
Scope::~Scope() {
for (Variable* v : vars_) delete v;
for (Scope* s : kids_) delete s;
}
Scope& Scope::NewScope() {
kids_.push_back(new Scope(this));
return *kids_.back();
}
Variable* Scope::NewVar(const std::string& name) {
atuo iter = vars_.find(name);
if (iter != vars_.end()) {
return iter.second->get();
}
Variable* v = new Variable();
v->name_ = name;
var_[name] = v;
return v;
}
Variable* Scope::NewVar() {
return NewVar(string.Sprintf("%p.%d", this, vars_.size()));
}
Variable* Scope::FindVar(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) return it->second.get();
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
}
Scope* Scope::FindScope(const Variable* var) const {
if (FindVar(var->name_) != nullptr) return this;
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
}
} // namespace framework
} // namespace paddle
...@@ -14,9 +14,9 @@ limitations under the License. */ ...@@ -14,9 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <list>
#include <map>
#include <string> #include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/variable.h" #include "paddle/framework/variable.h"
...@@ -35,73 +35,36 @@ class Scope; ...@@ -35,73 +35,36 @@ class Scope;
*/ */
class Scope { class Scope {
public: public:
/**
* @brief Initialize s Scope without parent.
*/
Scope() {} Scope() {}
~Scope();
/** // Create a sub-scope. Returns a reference other than a pointer so
* @brief Initialize a Scope with parent. // to prevent from manual deletion.
*/ Scope& NewScope();
explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {}
// Create a variable with given name if it doesn't exist.
/** Variable* NewVar(const std::string& name);
* @brief Create Variable
* // Create a variable with a scope-unique name.
* Create Variable in this Scope. Return the exist one if Variable already Variable* NewVar();
* been created.
*/ // Find a variable in the scope or any of its ancestors. Returns
Variable* CreateVariable(const std::string& name) { // nullptr if cannot find.
auto var = GetVariable(name); Variable* FindVar(const std::string& name) const;
if (var) {
return var; // Find the scope or an ancestor scope that contains the given variable.
} else { Scope* FindScope(const Variable* var) const;
auto ptr = new Variable();
name_to_var_[name] = std::unique_ptr<Variable>(ptr); // Returns the name of a variable in this scope.
var_to_name_[ptr] = name; std::string VarName(const Variable* var) const { return var->name_; }
return GetVariable(name);
}
}
/**
* @brief Get Variable.
*
* Get Variable from this Scope, this function will recursive find Variable
* from it's parent scope. Return nullptr if not found.
*/
Variable* GetVariable(const std::string& name) const {
auto it = name_to_var_.find(name);
if (it != name_to_var_.end()) {
return it->second.get();
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
} else {
return nullptr;
}
}
/**
* @brief If this scope has a Var named name.
*
* Find if there is a Variable in this scope and it's parent scope
*/
bool HasVariable(const std::string& name) const {
return (name_to_var_.find(name) != name_to_var_.end() ||
(parent_ && parent_->HasVariable(name)));
}
std::string GetVariableName(Variable* const var) const {
try {
return var_to_name_.at(var);
} catch (...) {
return "";
}
}
private: private:
std::unordered_map<Variable*, std::string> var_to_name_; // Call Scope::NewScope for a sub-scope.
std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_; explicit Scope(Scope* parent) : parent_(parent) {}
std::shared_ptr<Scope> parent_{nullptr};
std::map<std::string, Variable*> vars_;
std::list<Scope*> kids_;
Scope* parent_{nullptr};
}; };
} // namespace framework } // namespace framework
......
...@@ -15,49 +15,42 @@ limitations under the License. */ ...@@ -15,49 +15,42 @@ limitations under the License. */
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
TEST(Scope, Create) { using paddle::framework::Scope;
using paddle::framework::Scope; using paddle::framework::Variable;
using paddle::framework::Variable;
auto scope = std::make_shared<Scope>(); TEST(Scope, VarsShadowing) {
Scope s;
Scope& ss1 = s.NewScope();
Scope& ss2 = s.NewScope();
Variable* var0 = scope->CreateVariable(""); Variable* v0 = s.NewVar("a");
EXPECT_NE(var0, nullptr); Variable* v1 = ss1.NewVar("a");
/// GetVariable will return nullptr if not exist. EXPECT_NE(v0, v1);
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var1, nullptr);
/// CreateVariable will return one. EXPECT_EQ(v0, s.FindVar("a"));
Variable* var2 = scope->CreateVariable("a"); EXPECT_EQ(v1, ss1.FindVar("a"));
EXPECT_NE(var2, nullptr); EXPECT_EQ(v0, ss2.FindVar("a"));
}
/// Get the created variable.
Variable* var3 = scope->GetVariable("a");
EXPECT_EQ(var2, var3);
/// CreateVariable will just return the variable if it's TEST(Scope, FindVar) {
/// already exist. Scope s;
Variable* var4 = scope->CreateVariable("a"); Scope& ss = s.NewScope();
EXPECT_EQ(var4, var2);
EXPECT_EQ("a", scope->GetVariableName(var4)); EXPECT_EQ(nullptr, s.FindVar("a"));
Scope scope2; EXPECT_EQ(nullptr, ss.FindVar("a"));
auto var = scope2.CreateVariable("tmp");
EXPECT_EQ("", scope->GetVariableName(var));
}
TEST(Scope, Parent) { ss.NewVar("a");
using paddle::framework::Scope;
using paddle::framework::Variable;
auto parent_scope = std::make_shared<Scope>(); EXPECT_EQ(nullptr, s.FindVar("a"));
auto scope = std::make_shared<Scope>(parent_scope); EXPECT_NE(nullptr, ss.FindVar("a"));
}
Variable* var0 = parent_scope->CreateVariable("a"); TEST(Scope, FindScope) {
EXPECT_NE(var0, nullptr); Scope s;
Scope& ss = s.NewScope();
s.NewVar("a");
/// GetVariable will get Variable from parent scope if exist. EXPECT_EQ(&s, s.FindVar("a"));
Variable* var1 = scope->GetVariable("a"); EXPECT_EQ(&s, ss.FindVar("a"));
EXPECT_EQ(var0, var1);
} }
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <typeinfo> #include <typeinfo>
#include "paddle/platform/assert.h" #include "paddle/platform/assert.h"
#include "paddle/string/piece.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -65,6 +66,17 @@ class Variable { ...@@ -65,6 +66,17 @@ class Variable {
std::unique_ptr<Placeholder> std::unique_ptr<Placeholder>
holder_; // pointers to a PlaceholderImpl object indeed. holder_; // pointers to a PlaceholderImpl object indeed.
// name_ is only meaningful with a Scope and accessible by it.
//
// NOTE: Please don't expose name_ by adding methods like
// Variable::Name or Scope::VarName! A variable could have a human
// readable name or an auto-generated scope-unique name. In the
// former case, the caller knows the name and doesn't need to access
// the name; in the latter case, the variable should be identified
// by its address but not the unreadable name.
friend class Scope;
string::Piece name_;
}; };
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册