Commit 0ceeacbe authored by Yu Yang

Make Scope able to look up a variable's name from its Variable pointer

* Also refine the unit tests
Parent 0ab678e9
@@ -56,7 +56,9 @@ class Scope {
     if (var) {
       return var;
     } else {
-      vars_[name] = std::unique_ptr<Variable>(new Variable());
+      auto ptr = new Variable();
+      vars_[name] = std::unique_ptr<Variable>(ptr);
+      var_names_[ptr] = name;
       return GetVariable(name);
     }
   }
@@ -88,7 +90,16 @@ class Scope {
             (parent_ && parent_->HasVariable(name)));
   }

+  std::string GetVariableName(Variable* const var) const {
+    try {
+      return var_names_.at(var);
+    } catch (...) {
+      return "";
+    }
+  }
+
  private:
+  std::unordered_map<Variable*, std::string> var_names_;
   std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
   std::shared_ptr<Scope> parent_{nullptr};
 };
...
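The hunks above add a reverse index: alongside the existing name-to-Variable map (vars_), the Scope now records Variable*-to-name (var_names_) at creation time, so GetVariableName can recover the registered name for a pointer and returns an empty string for a pointer the scope does not own. Below is a minimal, self-contained sketch of the same bookkeeping pattern; Registry and Value are illustrative placeholders, not the actual Paddle classes.

// registry_sketch.cc -- illustrative only; mirrors the vars_/var_names_ pair.
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

struct Value {};  // stand-in for a Variable

class Registry {
 public:
  // Create-or-get by name; record the reverse mapping on first creation.
  Value* CreateValue(const std::string& name) {
    auto it = values_.find(name);
    if (it != values_.end()) return it->second.get();
    auto* ptr = new Value();
    values_[name] = std::unique_ptr<Value>(ptr);
    names_[ptr] = name;  // reverse index, same idea as var_names_
    return ptr;
  }

  // Reverse lookup; "" when the pointer is unknown, matching the
  // catch-all behaviour of GetVariableName above.
  std::string GetName(Value* ptr) const {
    auto it = names_.find(ptr);
    return it == names_.end() ? "" : it->second;
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Value>> values_;
  std::unordered_map<Value*, std::string> names_;
};

int main() {
  Registry reg;
  Value* a = reg.CreateValue("a");
  assert(reg.GetName(a) == "a");
  assert(reg.GetName(nullptr) == "");  // unknown pointer
  return 0;
}

The sketch uses find() on the miss path instead of at() plus catch (...), which avoids raising an exception for unknown pointers, but the observable result is the same empty string the patch returns.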
@@ -40,6 +40,8 @@ TEST(Scope, Create) {
   /// already exist.
   Variable* var4 = scope->CreateVariable("a");
   EXPECT_EQ(var4, var2);
+
+  EXPECT_EQ("a", scope->GetVariableName(var4));
 }

 TEST(Scope, Parent) {
...
@@ -56,6 +56,11 @@ void ExposeOperator(ClassType& m) {
       .def("__str__", &ClassType::type::DebugString);
 }

+static size_t UniqueIntegerGenerator() {
+  static std::atomic<size_t> generator;
+  return generator.fetch_add(1);
+}
+
 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of PaddlePaddle");
@@ -106,7 +111,8 @@ All parameter, weight, gradient are variables in Paddle.
            py::return_value_policy::reference)
       .def("create_var",
            &pd::Scope::CreateVariable,
-           py::return_value_policy::reference);
+           py::return_value_policy::reference)
+      .def("get_var_name", &pd::Scope::GetVariableName);

   //! @note: Be careful! PyBind will return std::string as an unicode, not
   //! Python str. If you want a str object, you should cast them in Python.
@@ -166,5 +172,7 @@ All parameter, weight, gradient are variables in Paddle.
       .def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
   ExposeOperator(net);

+  m.def("unique_integer", UniqueIntegerGenerator);
+
   return m.ptr();
 }
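For readers less familiar with pybind11, the two binding forms touched here behave differently on the Python side: a .def on the bound Scope class becomes a method (scope.get_var_name(var)), while m.def registers a module-level function (core.unique_integer()). A minimal generic module showing both forms, using the same old-style PYBIND11_PLUGIN macro as the file above; the names are hypothetical, not PaddlePaddle's actual bindings.

// binding_sketch.cc -- generic pybind11 illustration, not Paddle's core module.
#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

struct Thing {
  std::string name() const { return "thing"; }
};

static size_t NextId() {
  static size_t id = 0;
  return id++;
}

PYBIND11_PLUGIN(sketch) {
  py::module m("sketch", "pybind11 binding sketch");

  // Member function exposed as a method on a Python class,
  // analogous to .def("get_var_name", &pd::Scope::GetVariableName).
  py::class_<Thing>(m, "Thing")
      .def(py::init<>())
      .def("name", &Thing::name);

  // Free function exposed at module level,
  // analogous to m.def("unique_integer", UniqueIntegerGenerator).
  m.def("next_id", &NextId);

  return m.ptr();
}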
@@ -29,35 +29,31 @@ class NetworkFunctor(object):
             if ipt in kwargs:
                 var = kwargs[ipt]
                 if isinstance(var, basestring):
-                    var_name = var
                     var = create_var(var)
-                    self.net.var_name_map[var] = var_name
                 if not isinstance(var, core.Variable):
                     raise TypeError(
                         "Input of op creation must be string or variable")
-                kwargs[ipt] = self.net.var_name_map[var]
+                kwargs[ipt] = get_cur_scope().get_var_name(var)

         notemp_outputs = self.func.all_not_temp_output_args

         for name in notemp_outputs:
             if name not in kwargs:
-                kwargs[
-                    name] = self.func.__name__ + "@OUT@%d" % self.net.generate_idx
-                self.net.generate_idx += 1
+                kwargs[
+                    name] = self.func.__name__ + "@OUT@%d" % core.unique_integer(
+                )

         outputs = self.func.all_output_args
         for opt in outputs:
             if opt in kwargs:
                 var = kwargs[opt]
                 if isinstance(var, basestring):
-                    var_name = var
                     var = create_var(var)
-                    self.net.var_name_map[var] = var_name
                 if not isinstance(var, core.Variable):
                     raise TypeError(
                         "Output of op creation must be string or variable")
-                kwargs[opt] = self.net.var_name_map[var]
+                kwargs[opt] = get_cur_scope().get_var_name(var)

         op = self.func(**kwargs)
@@ -93,8 +89,6 @@ class Network(object):
         self.net = core.Net.create()
         funcs = (func_name for func_name in dir(op_creations)
                  if not func_name.startswith("__"))
-        self.generate_idx = 0
-        self.var_name_map = dict()

         # TODO(yuyang18): This code can work, but do not generate a good
         # docstring, try to give a better way generate function in runtime
...
@@ -18,6 +18,15 @@ class TestNet(unittest.TestCase):
     Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
 ''', str(net))

+        net2 = Network()
+        tmp = net2.add_two(X="X", Y="Y")
+        self.assertTrue(isinstance(tmp, core.Variable))
+        net2.complete_add_op()
+
+        self.assertEqual(
+            '''Op(naive_net), inputs:(X, Y), outputs:(add_two@OUT@2).
+    Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
+''', str(net2))

 if __name__ == '__main__':
     unittest.main()