From 754f0c68da61ae4b7a5a67cdc9d841159bd73fbe Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 25 Jul 2017 15:26:01 +0800 Subject: [PATCH] Fix unittest --- paddle/framework/scope.h | 16 ++++++++-------- paddle/framework/scope_test.cc | 3 +++ paddle/pybind/pybind.cc | 10 +--------- python/paddle/v2/framework/tests/test_network.py | 4 ++-- 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index cbbccf465d..4faaf84144 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -57,8 +57,8 @@ class Scope { return var; } else { auto ptr = new Variable(); - vars_[name] = std::unique_ptr(ptr); - var_names_[ptr] = name; + name_to_var_[name] = std::unique_ptr(ptr); + var_to_name_[ptr] = name; return GetVariable(name); } } @@ -70,8 +70,8 @@ class Scope { * from it's parent scope. Return nullptr if not found. */ Variable* GetVariable(const std::string& name) const { - auto it = vars_.find(name); - if (it != vars_.end()) { + auto it = name_to_var_.find(name); + if (it != name_to_var_.end()) { return it->second.get(); } else if (parent_ != nullptr) { return parent_->GetVariable(name); @@ -86,21 +86,21 @@ class Scope { * Find if there is a Variable in this scope and it's parent scope */ bool HasVariable(const std::string& name) const { - return (vars_.find(name) != vars_.end() || + return (name_to_var_.find(name) != name_to_var_.end() || (parent_ && parent_->HasVariable(name))); } std::string GetVariableName(Variable* const var) const { try { - return var_names_.at(var); + return var_to_name_.at(var); } catch (...) { return ""; } } private: - std::unordered_map var_names_; - std::unordered_map> vars_; + std::unordered_map var_to_name_; + std::unordered_map> name_to_var_; std::shared_ptr parent_{nullptr}; }; diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index 51de74ddfe..ff069c7be0 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -42,6 +42,9 @@ TEST(Scope, Create) { EXPECT_EQ(var4, var2); EXPECT_EQ("a", scope->GetVariableName(var4)); + Scope scope2; + auto var = scope2.CreateVariable("tmp"); + EXPECT_EQ("", scope->GetVariableName(var)); } TEST(Scope, Parent) { diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 3588004122..0b152d03c0 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -15,14 +15,6 @@ limitations under the License. */ #include #include #include -#include "paddle/framework/net.h" -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" -#include "paddle/framework/scope.h" -#include "paddle/pybind/tensor_bind.h" -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" #include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" @@ -160,7 +152,7 @@ All parameter, weight, gradient are variables in Paddle. net.def_static("create", []() -> std::shared_ptr { auto retv = std::make_shared(); - retv->type_ = "naive_net"; + retv->type_ = "plain_net"; return retv; }) .def("add_op", &pd::PlainNet::AddOp) diff --git a/python/paddle/v2/framework/tests/test_network.py b/python/paddle/v2/framework/tests/test_network.py index 457f8f13a6..6d53e233e9 100644 --- a/python/paddle/v2/framework/tests/test_network.py +++ b/python/paddle/v2/framework/tests/test_network.py @@ -11,7 +11,7 @@ class TestNet(unittest.TestCase): net.complete_add_op() self.assertTrue(isinstance(fc_out, core.Variable)) self.assertEqual( - '''Op(naive_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1). + '''Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1). Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0). Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0). Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0). @@ -23,7 +23,7 @@ class TestNet(unittest.TestCase): self.assertTrue(isinstance(tmp, core.Variable)) net2.complete_add_op() self.assertEqual( - '''Op(naive_net), inputs:(X, Y), outputs:(add_two@OUT@2). + '''Op(plain_net), inputs:(X, Y), outputs:(add_two@OUT@2). Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2). ''', str(net2)) -- GitLab