diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 21cb7c7265e0052630b68954fa25f9189e641e7b..b74fa3581f0ced53a6f7d21081791b8129714160 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -8,7 +8,9 @@ cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(variable_test SRCS variable_test.cc) -cc_test(scope_test SRCS scope_test.cc) + +cc_library(scope SRCS scope.cc) +cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(attr_type SRCS attr_type.proto) proto_library(op_proto SRCS op_proto.proto DEPS attr_type) diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index ad5360d98fbdfe9972df71b2dde10ba21b0cc837..3c9ec92d7200f2fd2d88627603637488ff2c1c19 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/scope.h" +#include "paddle/string/printf.h" namespace paddle { namespace framework { Scope::~Scope() { - for (Variable* v : vars_) delete v; - + for (auto& kv : vars_) delete kv.second; for (Scope* s : kids_) delete s; } @@ -29,28 +29,32 @@ Scope& Scope::NewScope() { } Variable* Scope::NewVar(const std::string& name) { - atuo iter = vars_.find(name); + auto iter = vars_.find(name); if (iter != vars_.end()) { - return iter.second->get(); + return iter->second; } Variable* v = new Variable(); - v->name_ = name; - var_[name] = v; + vars_[name] = v; + v->name_ = &(vars_.find(name)->first); return v; } Variable* Scope::NewVar() { - return NewVar(string.Sprintf("%p.%d", this, vars_.size())); + return NewVar(string::Sprintf("%p.%d", this, vars_.size())); } Variable* Scope::FindVar(const std::string& name) const { auto it = vars_.find(name); - if (it != vars_.end()) return it->second.get(); + if (it != vars_.end()) return it->second; return (parent_ == nullptr) ? nullptr : parent_->FindVar(name); } -Scope* Scope::FindScope(const Variable* var) const { - if (FindVar(var->name_) != nullptr) return this; +Scope* Scope::FindScope(const Variable* var) { + for (auto& kv : vars_) { + if (kv.second == var) { + return this; + } + } return (parent_ == nullptr) ? nullptr : parent_->FindScope(var); } diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index b145ae3a4d4e1db12dc5ac6c3e2d7396b1231cfb..9b4fffb9a62390493941fceea208c32967081d1c 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -53,10 +53,7 @@ class Scope { Variable* FindVar(const std::string& name) const; // Find the scope or an ancestor scope that contains the given variable. - Scope* FindScope(const Variable* var) const; - - // Returns the name of a variable in this scope. - std::string VarName(const Variable* var) const { return var->name_; } + Scope* FindScope(const Variable* var); private: // Call Scope::NewScope for a sub-scope. diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index 6f5e735d82449c91d4e7e9c9560c98e27f7b282e..9d51e355b0f6336d2f875ff2d77266b261baf5ac 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -49,8 +49,8 @@ TEST(Scope, FindVar) { TEST(Scope, FindScope) { Scope s; Scope& ss = s.NewScope(); - s.NewVar("a"); + Variable* v = s.NewVar("a"); - EXPECT_EQ(&s, s.FindVar("a")); - EXPECT_EQ(&s, ss.FindVar("a")); + EXPECT_EQ(&s, s.FindScope(v)); + EXPECT_EQ(&s, ss.FindScope(v)); } diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h index 68a443a06ea9ebf84aa06b930ed9ee9a0476b0cb..10a3866b850bfd9b50141d5f869f471dd59e1b8b 100644 --- a/paddle/framework/variable.h +++ b/paddle/framework/variable.h @@ -17,7 +17,6 @@ #include #include "paddle/platform/assert.h" -#include "paddle/string/piece.h" namespace paddle { namespace framework { @@ -76,7 +75,7 @@ class Variable { // the name; in the latter case, the variable should be identified // by its address but not the unreadable name. friend class Scope; - string::Piece name_; + const std::string* name_; }; } // namespace framework