From 0ceeacbe455c3e74431dfd92c4d837a52869d424 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 24 Jul 2017 18:12:50 +0800 Subject: [PATCH] Make Scope can lookup variable name by variable * Refine unittest also --- paddle/framework/scope.h | 13 ++++++++++++- paddle/framework/scope_test.cc | 2 ++ paddle/pybind/pybind.cc | 10 +++++++++- python/paddle/v2/framework/network.py | 14 ++++---------- python/paddle/v2/framework/tests/test_network.py | 9 +++++++++ 5 files changed, 36 insertions(+), 12 deletions(-) diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 79c9ffd1a67..cbbccf465d4 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -56,7 +56,9 @@ class Scope { if (var) { return var; } else { - vars_[name] = std::unique_ptr(new Variable()); + auto ptr = new Variable(); + vars_[name] = std::unique_ptr(ptr); + var_names_[ptr] = name; return GetVariable(name); } } @@ -88,7 +90,16 @@ class Scope { (parent_ && parent_->HasVariable(name))); } + std::string GetVariableName(Variable* const var) const { + try { + return var_names_.at(var); + } catch (...) { + return ""; + } + } + private: + std::unordered_map var_names_; std::unordered_map> vars_; std::shared_ptr parent_{nullptr}; }; diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index df1afb200ce..51de74ddfef 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -40,6 +40,8 @@ TEST(Scope, Create) { /// already exist. Variable* var4 = scope->CreateVariable("a"); EXPECT_EQ(var4, var2); + + EXPECT_EQ("a", scope->GetVariableName(var4)); } TEST(Scope, Parent) { diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 43c52957a1e..35880041222 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -56,6 +56,11 @@ void ExposeOperator(ClassType& m) { .def("__str__", &ClassType::type::DebugString); } +static size_t UniqueIntegerGenerator() { + static std::atomic generator; + return generator.fetch_add(1); +} + PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of PaddlePaddle"); @@ -106,7 +111,8 @@ All parameter, weight, gradient are variables in Paddle. py::return_value_policy::reference) .def("create_var", &pd::Scope::CreateVariable, - py::return_value_policy::reference); + py::return_value_policy::reference) + .def("get_var_name", &pd::Scope::GetVariableName); //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. @@ -166,5 +172,7 @@ All parameter, weight, gradient are variables in Paddle. .def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); }); ExposeOperator(net); + m.def("unique_integer", UniqueIntegerGenerator); + return m.ptr(); } diff --git a/python/paddle/v2/framework/network.py b/python/paddle/v2/framework/network.py index ade7e4b85e6..c85e87413ef 100644 --- a/python/paddle/v2/framework/network.py +++ b/python/paddle/v2/framework/network.py @@ -29,35 +29,31 @@ class NetworkFunctor(object): if ipt in kwargs: var = kwargs[ipt] if isinstance(var, basestring): - var_name = var var = create_var(var) - self.net.var_name_map[var] = var_name if not isinstance(var, core.Variable): raise TypeError( "Input of op creation must be string or variable") - kwargs[ipt] = self.net.var_name_map[var] + kwargs[ipt] = get_cur_scope().get_var_name(var) notemp_outputs = self.func.all_not_temp_output_args for name in notemp_outputs: if name not in kwargs: kwargs[ - name] = self.func.__name__ + "@OUT@%d" % self.net.generate_idx - self.net.generate_idx += 1 + name] = self.func.__name__ + "@OUT@%d" % core.unique_integer( + ) outputs = self.func.all_output_args for opt in outputs: if opt in kwargs: var = kwargs[opt] if isinstance(var, basestring): - var_name = var var = create_var(var) - self.net.var_name_map[var] = var_name if not isinstance(var, core.Variable): raise TypeError( "Output of op creation must be string or variable") - kwargs[opt] = self.net.var_name_map[var] + kwargs[opt] = get_cur_scope().get_var_name(var) op = self.func(**kwargs) @@ -93,8 +89,6 @@ class Network(object): self.net = core.Net.create() funcs = (func_name for func_name in dir(op_creations) if not func_name.startswith("__")) - self.generate_idx = 0 - self.var_name_map = dict() # TODO(yuyang18): This code can work, but do not generate a good # docstring, try to give a better way generate function in runtime diff --git a/python/paddle/v2/framework/tests/test_network.py b/python/paddle/v2/framework/tests/test_network.py index 310290e34b0..457f8f13a62 100644 --- a/python/paddle/v2/framework/tests/test_network.py +++ b/python/paddle/v2/framework/tests/test_network.py @@ -18,6 +18,15 @@ class TestNet(unittest.TestCase): Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1). ''', str(net)) + net2 = Network() + tmp = net2.add_two(X="X", Y="Y") + self.assertTrue(isinstance(tmp, core.Variable)) + net2.complete_add_op() + self.assertEqual( + '''Op(naive_net), inputs:(X, Y), outputs:(add_two@OUT@2). + Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2). +''', str(net2)) + if __name__ == '__main__': unittest.main() -- GitLab