Commit 0ceeacbe authored by Yu Yang

Make Scope able to look up a variable's name by Variable pointer

* Also refine the unit tests
Parent 0ab678e9
......@@ -56,7 +56,9 @@ class Scope {
     if (var) {
       return var;
     } else {
-      vars_[name] = std::unique_ptr<Variable>(new Variable());
+      auto ptr = new Variable();
+      vars_[name] = std::unique_ptr<Variable>(ptr);
+      var_names_[ptr] = name;
       return GetVariable(name);
     }
   }
......@@ -88,7 +90,16 @@ class Scope {
             (parent_ && parent_->HasVariable(name)));
   }
 
+  std::string GetVariableName(Variable* const var) const {
+    try {
+      return var_names_.at(var);
+    } catch (...) {
+      return "";
+    }
+  }
+
  private:
+  std::unordered_map<Variable*, std::string> var_names_;
   std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
   std::shared_ptr<Scope> parent_{nullptr};
 };
......
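The Scope change above pairs the existing name-to-variable map with a reverse variable-to-name map, so a caller holding only a Variable pointer can recover its registered name (or an empty string if this scope never created it). Below is a minimal Python sketch of that bookkeeping; the class and names are illustrative stand-ins, not the real API.

```python
class ToyScope(object):
    """Illustrative stand-in for the C++ Scope above, not the real class."""

    def __init__(self):
        self._vars = {}    # name -> variable (mirrors vars_)
        self._names = {}   # id(variable) -> name (mirrors var_names_)

    def create_var(self, name):
        if name not in self._vars:
            var = object()                # placeholder for a real Variable
            self._vars[name] = var
            self._names[id(var)] = name   # reverse entry added on creation
        return self._vars[name]

    def get_var_name(self, var):
        # Unknown variables yield "", like GetVariableName's catch-all branch.
        return self._names.get(id(var), "")


scope = ToyScope()
w = scope.create_var("w")
assert scope.get_var_name(w) == "w"
assert scope.get_var_name(object()) == ""
```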
......@@ -40,6 +40,8 @@ TEST(Scope, Create) {
   /// already exist.
   Variable* var4 = scope->CreateVariable("a");
   EXPECT_EQ(var4, var2);
+  EXPECT_EQ("a", scope->GetVariableName(var4));
 }
 
 TEST(Scope, Parent) {
......
......@@ -56,6 +56,11 @@ void ExposeOperator(ClassType& m) {
       .def("__str__", &ClassType::type::DebugString);
 }
 
+static size_t UniqueIntegerGenerator() {
+  static std::atomic<size_t> generator;
+  return generator.fetch_add(1);
+}
+
 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of PaddlePaddle");
......@@ -106,7 +111,8 @@ All parameter, weight, gradient are variables in Paddle.
            py::return_value_policy::reference)
       .def("create_var",
            &pd::Scope::CreateVariable,
-           py::return_value_policy::reference);
+           py::return_value_policy::reference)
+      .def("get_var_name", &pd::Scope::GetVariableName);
 
   //! @note: Be careful! PyBind will return std::string as an unicode, not
   //! Python str. If you want a str object, you should cast them in Python.
......@@ -166,5 +172,7 @@ All parameter, weight, gradient are variables in Paddle.
       .def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
   ExposeOperator(net);
 
+  m.def("unique_integer", UniqueIntegerGenerator);
+
   return m.ptr();
 }
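UniqueIntegerGenerator, exposed to Python as core.unique_integer, returns values from a single process-wide counter: std::atomic's fetch_add guarantees that two concurrent callers never receive the same integer. A rough Python analog of the same idea follows; it is illustrative only and not part of the patch.

```python
import itertools
import threading

_counter = itertools.count()   # process-wide counter, like the static std::atomic
_lock = threading.Lock()


def unique_integer():
    """Hand out the next integer; the lock stands in for fetch_add's atomicity."""
    with _lock:
        return next(_counter)


print(unique_integer(), unique_integer(), unique_integer())  # 0 1 2
```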
......@@ -29,35 +29,31 @@ class NetworkFunctor(object):
             if ipt in kwargs:
                 var = kwargs[ipt]
                 if isinstance(var, basestring):
-                    var_name = var
                     var = create_var(var)
-                    self.net.var_name_map[var] = var_name
                 if not isinstance(var, core.Variable):
                     raise TypeError(
                         "Input of op creation must be string or variable")
-                kwargs[ipt] = self.net.var_name_map[var]
+                kwargs[ipt] = get_cur_scope().get_var_name(var)
 
         notemp_outputs = self.func.all_not_temp_output_args
 
         for name in notemp_outputs:
             if name not in kwargs:
                 kwargs[
-                    name] = self.func.__name__ + "@OUT@%d" % self.net.generate_idx
-                self.net.generate_idx += 1
+                    name] = self.func.__name__ + "@OUT@%d" % core.unique_integer(
+                    )
 
         outputs = self.func.all_output_args
         for opt in outputs:
             if opt in kwargs:
                 var = kwargs[opt]
                 if isinstance(var, basestring):
-                    var_name = var
                     var = create_var(var)
-                    self.net.var_name_map[var] = var_name
                 if not isinstance(var, core.Variable):
                     raise TypeError(
                         "Output of op creation must be string or variable")
-                kwargs[opt] = self.net.var_name_map[var]
+                kwargs[opt] = get_cur_scope().get_var_name(var)
 
         op = self.func(**kwargs)
......@@ -93,8 +89,6 @@ class Network(object):
         self.net = core.Net.create()
         funcs = (func_name for func_name in dir(op_creations)
                  if not func_name.startswith("__"))
-        self.generate_idx = 0
-        self.var_name_map = dict()
 
         # TODO(yuyang18): This code can work, but do not generate a good
         # docstring, try to give a better way generate function in runtime
......
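With the scope able to answer get_var_name, the op-creation wrapper no longer needs its per-Network var_name_map or generate_idx: string arguments are materialized into scope variables, Variable arguments are accepted directly, and either way the kwarg handed to the op is the canonical name from the current scope. Unnamed outputs get a name built from the global counter instead of a per-network index, which is why the second network in the test below is assigned add_two@OUT@2 rather than restarting at 0. Here is a hedged, self-contained sketch of that resolution logic; the helper names are made up.

```python
import itertools

_vars, _names = {}, {}              # toy scope: name -> var, id(var) -> name
_counter = itertools.count()        # stand-in for core.unique_integer()


def create_var(name):
    if name not in _vars:
        var = object()              # placeholder for a real Variable
        _vars[name] = var
        _names[id(var)] = name
    return _vars[name]


def get_var_name(var):
    return _names.get(id(var), "")


def resolve_kwarg(value, op_name):
    """Return the canonical variable name for a string/Variable kwarg,
    or generate a globally unique default name when none was given."""
    if value is None:
        return "%s@OUT@%d" % (op_name, next(_counter))
    if isinstance(value, str):      # 'basestring' in the Python 2 original
        value = create_var(value)
    return get_var_name(value)


x = create_var("X")
assert resolve_kwarg("X", "add_two") == "X"    # string resolves via the scope
assert resolve_kwarg(x, "add_two") == "X"      # variable resolves to its name
print(resolve_kwarg(None, "add_two"))          # add_two@OUT@0
```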
......@@ -18,6 +18,15 @@ class TestNet(unittest.TestCase):
 Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
 ''', str(net))
 
+        net2 = Network()
+        tmp = net2.add_two(X="X", Y="Y")
+        self.assertTrue(isinstance(tmp, core.Variable))
+        net2.complete_add_op()
+
+        self.assertEqual(
+            '''Op(naive_net), inputs:(X, Y), outputs:(add_two@OUT@2).
+Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
+''', str(net2))
 
 if __name__ == '__main__':
     unittest.main()