From 5031c93aebad67d7d53ead8a766b81512e3296de Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Sun, 30 Jul 2017 14:25:03 -0700 Subject: [PATCH] Pass test --- paddle/framework/CMakeLists.txt | 4 +++- paddle/framework/scope.cc | 24 ++++++++++++++---------- paddle/framework/scope.h | 5 +---- paddle/framework/scope_test.cc | 6 +++--- paddle/framework/variable.h | 3 +-- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 21cb7c7265..b74fa3581f 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -8,7 +8,9 @@ cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(variable_test SRCS variable_test.cc) -cc_test(scope_test SRCS scope_test.cc) + +cc_library(scope SRCS scope.cc) +cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(attr_type SRCS attr_type.proto) proto_library(op_proto SRCS op_proto.proto DEPS attr_type) diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index ad5360d98f..3c9ec92d72 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/scope.h" +#include "paddle/string/printf.h" namespace paddle { namespace framework { Scope::~Scope() { - for (Variable* v : vars_) delete v; - + for (auto& kv : vars_) delete kv.second; for (Scope* s : kids_) delete s; } @@ -29,28 +29,32 @@ Scope& Scope::NewScope() { } Variable* Scope::NewVar(const std::string& name) { - atuo iter = vars_.find(name); + auto iter = vars_.find(name); if (iter != vars_.end()) { - return iter.second->get(); + return iter->second; } Variable* v = new Variable(); - v->name_ = name; - var_[name] = v; + vars_[name] = v; + v->name_ = &(vars_.find(name)->first); return v; } Variable* Scope::NewVar() { - return NewVar(string.Sprintf("%p.%d", this, vars_.size())); + return NewVar(string::Sprintf("%p.%d", this, vars_.size())); } Variable* Scope::FindVar(const std::string& name) const { auto it = vars_.find(name); - if (it != vars_.end()) return it->second.get(); + if (it != vars_.end()) return it->second; return (parent_ == nullptr) ? nullptr : parent_->FindVar(name); } -Scope* Scope::FindScope(const Variable* var) const { - if (FindVar(var->name_) != nullptr) return this; +Scope* Scope::FindScope(const Variable* var) { + for (auto& kv : vars_) { + if (kv.second == var) { + return this; + } + } return (parent_ == nullptr) ? nullptr : parent_->FindScope(var); } diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index b145ae3a4d..9b4fffb9a6 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -53,10 +53,7 @@ class Scope { Variable* FindVar(const std::string& name) const; // Find the scope or an ancestor scope that contains the given variable. - Scope* FindScope(const Variable* var) const; - - // Returns the name of a variable in this scope. - std::string VarName(const Variable* var) const { return var->name_; } + Scope* FindScope(const Variable* var); private: // Call Scope::NewScope for a sub-scope. diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index 6f5e735d82..9d51e355b0 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -49,8 +49,8 @@ TEST(Scope, FindVar) { TEST(Scope, FindScope) { Scope s; Scope& ss = s.NewScope(); - s.NewVar("a"); + Variable* v = s.NewVar("a"); - EXPECT_EQ(&s, s.FindVar("a")); - EXPECT_EQ(&s, ss.FindVar("a")); + EXPECT_EQ(&s, s.FindScope(v)); + EXPECT_EQ(&s, ss.FindScope(v)); } diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h index 68a443a06e..10a3866b85 100644 --- a/paddle/framework/variable.h +++ b/paddle/framework/variable.h @@ -17,7 +17,6 @@ #include #include "paddle/platform/assert.h" -#include "paddle/string/piece.h" namespace paddle { namespace framework { @@ -76,7 +75,7 @@ class Variable { // the name; in the latter case, the variable should be identified // by its address but not the unreadable name. friend class Scope; - string::Piece name_; + const std::string* name_; }; } // namespace framework -- GitLab