提交 c5afddc6 编写于 作者: Y Yi Wang

Rewrite Scope

上级 0973c2c9
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/scope.h"
namespace paddle {
namespace framework {
Scope::~Scope() {
for (Variable* v : vars_) delete v;
for (Scope* s : kids_) delete s;
}
Scope& Scope::NewScope() {
kids_.push_back(new Scope(this));
return *kids_.back();
}
Variable* Scope::NewVar(const std::string& name) {
atuo iter = vars_.find(name);
if (iter != vars_.end()) {
return iter.second->get();
}
Variable* v = new Variable();
v->name_ = name;
var_[name] = v;
return v;
}
Variable* Scope::NewVar() {
return NewVar(string.Sprintf("%p.%d", this, vars_.size()));
}
Variable* Scope::FindVar(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) return it->second.get();
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
}
Scope* Scope::FindScope(const Variable* var) const {
if (FindVar(var->name_) != nullptr) return this;
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
}
} // namespace framework
} // namespace paddle
......@@ -14,9 +14,9 @@ limitations under the License. */
#pragma once
#include <list>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/variable.h"
......@@ -35,73 +35,36 @@ class Scope;
*/
class Scope {
public:
/**
* @brief Initialize s Scope without parent.
*/
Scope() {}
~Scope();
/**
* @brief Initialize a Scope with parent.
*/
explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {}
// Create a sub-scope. Returns a reference other than a pointer so
// to prevent from manual deletion.
Scope& NewScope();
/**
* @brief Create Variable
*
* Create Variable in this Scope. Return the exist one if Variable already
* been created.
*/
Variable* CreateVariable(const std::string& name) {
auto var = GetVariable(name);
if (var) {
return var;
} else {
auto ptr = new Variable();
name_to_var_[name] = std::unique_ptr<Variable>(ptr);
var_to_name_[ptr] = name;
return GetVariable(name);
}
}
/**
* @brief Get Variable.
*
* Get Variable from this Scope, this function will recursive find Variable
* from it's parent scope. Return nullptr if not found.
*/
Variable* GetVariable(const std::string& name) const {
auto it = name_to_var_.find(name);
if (it != name_to_var_.end()) {
return it->second.get();
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
} else {
return nullptr;
}
}
/**
* @brief If this scope has a Var named name.
*
* Find if there is a Variable in this scope and it's parent scope
*/
bool HasVariable(const std::string& name) const {
return (name_to_var_.find(name) != name_to_var_.end() ||
(parent_ && parent_->HasVariable(name)));
}
std::string GetVariableName(Variable* const var) const {
try {
return var_to_name_.at(var);
} catch (...) {
return "";
}
}
// Create a variable with given name if it doesn't exist.
Variable* NewVar(const std::string& name);
// Create a variable with a scope-unique name.
Variable* NewVar();
// Find a variable in the scope or any of its ancestors. Returns
// nullptr if cannot find.
Variable* FindVar(const std::string& name) const;
// Find the scope or an ancestor scope that contains the given variable.
Scope* FindScope(const Variable* var) const;
// Returns the name of a variable in this scope.
std::string VarName(const Variable* var) const { return var->name_; }
private:
std::unordered_map<Variable*, std::string> var_to_name_;
std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_;
std::shared_ptr<Scope> parent_{nullptr};
// Call Scope::NewScope for a sub-scope.
explicit Scope(Scope* parent) : parent_(parent) {}
std::map<std::string, Variable*> vars_;
std::list<Scope*> kids_;
Scope* parent_{nullptr};
};
} // namespace framework
......
......@@ -15,49 +15,42 @@ limitations under the License. */
#include "paddle/framework/scope.h"
#include "gtest/gtest.h"
TEST(Scope, Create) {
using paddle::framework::Scope;
using paddle::framework::Variable;
using paddle::framework::Scope;
using paddle::framework::Variable;
auto scope = std::make_shared<Scope>();
TEST(Scope, VarsShadowing) {
Scope s;
Scope& ss1 = s.NewScope();
Scope& ss2 = s.NewScope();
Variable* var0 = scope->CreateVariable("");
EXPECT_NE(var0, nullptr);
Variable* v0 = s.NewVar("a");
Variable* v1 = ss1.NewVar("a");
/// GetVariable will return nullptr if not exist.
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var1, nullptr);
EXPECT_NE(v0, v1);
/// CreateVariable will return one.
Variable* var2 = scope->CreateVariable("a");
EXPECT_NE(var2, nullptr);
/// Get the created variable.
Variable* var3 = scope->GetVariable("a");
EXPECT_EQ(var2, var3);
EXPECT_EQ(v0, s.FindVar("a"));
EXPECT_EQ(v1, ss1.FindVar("a"));
EXPECT_EQ(v0, ss2.FindVar("a"));
}
/// CreateVariable will just return the variable if it's
/// already exist.
Variable* var4 = scope->CreateVariable("a");
EXPECT_EQ(var4, var2);
TEST(Scope, FindVar) {
Scope s;
Scope& ss = s.NewScope();
EXPECT_EQ("a", scope->GetVariableName(var4));
Scope scope2;
auto var = scope2.CreateVariable("tmp");
EXPECT_EQ("", scope->GetVariableName(var));
}
EXPECT_EQ(nullptr, s.FindVar("a"));
EXPECT_EQ(nullptr, ss.FindVar("a"));
TEST(Scope, Parent) {
using paddle::framework::Scope;
using paddle::framework::Variable;
ss.NewVar("a");
auto parent_scope = std::make_shared<Scope>();
auto scope = std::make_shared<Scope>(parent_scope);
EXPECT_EQ(nullptr, s.FindVar("a"));
EXPECT_NE(nullptr, ss.FindVar("a"));
}
Variable* var0 = parent_scope->CreateVariable("a");
EXPECT_NE(var0, nullptr);
TEST(Scope, FindScope) {
Scope s;
Scope& ss = s.NewScope();
s.NewVar("a");
/// GetVariable will get Variable from parent scope if exist.
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var0, var1);
EXPECT_EQ(&s, s.FindVar("a"));
EXPECT_EQ(&s, ss.FindVar("a"));
}
......@@ -17,6 +17,7 @@
#include <typeinfo>
#include "paddle/platform/assert.h"
#include "paddle/string/piece.h"
namespace paddle {
namespace framework {
......@@ -65,6 +66,17 @@ class Variable {
std::unique_ptr<Placeholder>
holder_; // pointers to a PlaceholderImpl object indeed.
// name_ is only meaningful with a Scope and accessible by it.
//
// NOTE: Please don't expose name_ by adding methods like
// Variable::Name or Scope::VarName! A variable could have a human
// readable name or an auto-generated scope-unique name. In the
// former case, the caller knows the name and doesn't need to access
// the name; in the latter case, the variable should be identified
// by its address but not the unreadable name.
friend class Scope;
string::Piece name_;
};
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册