提交 5031c93a 编写于 作者: Y Yi Wang

Pass test

上级 c5afddc6
......@@ -8,7 +8,9 @@ cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc)
cc_library(scope SRCS scope.cc)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(attr_type SRCS attr_type.proto)
proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
......
......@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/scope.h"
#include "paddle/string/printf.h"
namespace paddle {
namespace framework {
Scope::~Scope() {
for (Variable* v : vars_) delete v;
for (auto& kv : vars_) delete kv.second;
for (Scope* s : kids_) delete s;
}
......@@ -29,28 +29,32 @@ Scope& Scope::NewScope() {
}
Variable* Scope::NewVar(const std::string& name) {
atuo iter = vars_.find(name);
auto iter = vars_.find(name);
if (iter != vars_.end()) {
return iter.second->get();
return iter->second;
}
Variable* v = new Variable();
v->name_ = name;
var_[name] = v;
vars_[name] = v;
v->name_ = &(vars_.find(name)->first);
return v;
}
Variable* Scope::NewVar() {
return NewVar(string.Sprintf("%p.%d", this, vars_.size()));
return NewVar(string::Sprintf("%p.%d", this, vars_.size()));
}
Variable* Scope::FindVar(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) return it->second.get();
if (it != vars_.end()) return it->second;
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
}
Scope* Scope::FindScope(const Variable* var) const {
if (FindVar(var->name_) != nullptr) return this;
Scope* Scope::FindScope(const Variable* var) {
for (auto& kv : vars_) {
if (kv.second == var) {
return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
}
......
......@@ -53,10 +53,7 @@ class Scope {
Variable* FindVar(const std::string& name) const;
// Find the scope or an ancestor scope that contains the given variable.
Scope* FindScope(const Variable* var) const;
// Returns the name of a variable in this scope.
std::string VarName(const Variable* var) const { return var->name_; }
Scope* FindScope(const Variable* var);
private:
// Call Scope::NewScope for a sub-scope.
......
......@@ -49,8 +49,8 @@ TEST(Scope, FindVar) {
TEST(Scope, FindScope) {
Scope s;
Scope& ss = s.NewScope();
s.NewVar("a");
Variable* v = s.NewVar("a");
EXPECT_EQ(&s, s.FindVar("a"));
EXPECT_EQ(&s, ss.FindVar("a"));
EXPECT_EQ(&s, s.FindScope(v));
EXPECT_EQ(&s, ss.FindScope(v));
}
......@@ -17,7 +17,6 @@
#include <typeinfo>
#include "paddle/platform/assert.h"
#include "paddle/string/piece.h"
namespace paddle {
namespace framework {
......@@ -76,7 +75,7 @@ class Variable {
// the name; in the latter case, the variable should be identified
// by its address but not the unreadable name.
friend class Scope;
string::Piece name_;
const std::string* name_;
};
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册