提交 5031c93a 编写于 作者: Y Yi Wang

Pass test

上级 c5afddc6
...@@ -8,7 +8,9 @@ cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) ...@@ -8,7 +8,9 @@ cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc)
cc_library(scope SRCS scope.cc)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(attr_type SRCS attr_type.proto) proto_library(attr_type SRCS attr_type.proto)
proto_library(op_proto SRCS op_proto.proto DEPS attr_type) proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
......
...@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/string/printf.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
Scope::~Scope() { Scope::~Scope() {
for (Variable* v : vars_) delete v; for (auto& kv : vars_) delete kv.second;
for (Scope* s : kids_) delete s; for (Scope* s : kids_) delete s;
} }
...@@ -29,28 +29,32 @@ Scope& Scope::NewScope() { ...@@ -29,28 +29,32 @@ Scope& Scope::NewScope() {
} }
Variable* Scope::NewVar(const std::string& name) { Variable* Scope::NewVar(const std::string& name) {
atuo iter = vars_.find(name); auto iter = vars_.find(name);
if (iter != vars_.end()) { if (iter != vars_.end()) {
return iter.second->get(); return iter->second;
} }
Variable* v = new Variable(); Variable* v = new Variable();
v->name_ = name; vars_[name] = v;
var_[name] = v; v->name_ = &(vars_.find(name)->first);
return v; return v;
} }
Variable* Scope::NewVar() { Variable* Scope::NewVar() {
return NewVar(string.Sprintf("%p.%d", this, vars_.size())); return NewVar(string::Sprintf("%p.%d", this, vars_.size()));
} }
Variable* Scope::FindVar(const std::string& name) const { Variable* Scope::FindVar(const std::string& name) const {
auto it = vars_.find(name); auto it = vars_.find(name);
if (it != vars_.end()) return it->second.get(); if (it != vars_.end()) return it->second;
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name); return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
} }
Scope* Scope::FindScope(const Variable* var) const { Scope* Scope::FindScope(const Variable* var) {
if (FindVar(var->name_) != nullptr) return this; for (auto& kv : vars_) {
if (kv.second == var) {
return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var); return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
} }
......
...@@ -53,10 +53,7 @@ class Scope { ...@@ -53,10 +53,7 @@ class Scope {
Variable* FindVar(const std::string& name) const; Variable* FindVar(const std::string& name) const;
// Find the scope or an ancestor scope that contains the given variable. // Find the scope or an ancestor scope that contains the given variable.
Scope* FindScope(const Variable* var) const; Scope* FindScope(const Variable* var);
// Returns the name of a variable in this scope.
std::string VarName(const Variable* var) const { return var->name_; }
private: private:
// Call Scope::NewScope for a sub-scope. // Call Scope::NewScope for a sub-scope.
......
...@@ -49,8 +49,8 @@ TEST(Scope, FindVar) { ...@@ -49,8 +49,8 @@ TEST(Scope, FindVar) {
TEST(Scope, FindScope) { TEST(Scope, FindScope) {
Scope s; Scope s;
Scope& ss = s.NewScope(); Scope& ss = s.NewScope();
s.NewVar("a"); Variable* v = s.NewVar("a");
EXPECT_EQ(&s, s.FindVar("a")); EXPECT_EQ(&s, s.FindScope(v));
EXPECT_EQ(&s, ss.FindVar("a")); EXPECT_EQ(&s, ss.FindScope(v));
} }
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include <typeinfo> #include <typeinfo>
#include "paddle/platform/assert.h" #include "paddle/platform/assert.h"
#include "paddle/string/piece.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -76,7 +75,7 @@ class Variable { ...@@ -76,7 +75,7 @@ class Variable {
// the name; in the latter case, the variable should be identified // the name; in the latter case, the variable should be identified
// by its address but not the unreadable name. // by its address but not the unreadable name.
friend class Scope; friend class Scope;
string::Piece name_; const std::string* name_;
}; };
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册