提交 b3df1f4a 编写于 作者: D Dong Zhihong

"fix tests"

上级 2434b8f5
...@@ -151,10 +151,10 @@ void BindBlockDesc(py::module &m) { ...@@ -151,10 +151,10 @@ void BindBlockDesc(py::module &m) {
return self.Var(name); return self.Var(name);
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("var", .def("find_var",
[](BlockDescBind &self, py::bytes byte_name) { [](BlockDescBind &self, py::bytes byte_name) {
std::string name = byte_name; std::string name = byte_name;
return self.Var(name); return self.FindVar(name);
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("all_vars", &BlockDescBind::AllVars, .def("all_vars", &BlockDescBind::AllVars,
......
...@@ -21,7 +21,7 @@ class Variable(object): ...@@ -21,7 +21,7 @@ class Variable(object):
if name is None: if name is None:
name = Variable._unique_var_name_() name = Variable._unique_var_name_()
try: try:
self.desc = self.block.desc.var(name) self.desc = self.block.desc.find_var(name)
is_new_var = False is_new_var = False
except core.EnforceNotMet: except core.EnforceNotMet:
self.desc = self.block.desc.var(name) self.desc = self.block.desc.var(name)
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
def create_tensor(scope, name, shape, np_data): def create_tensor(scope, name, shape, np_data):
tensor = scope.new_var(name).get_tensor() tensor = scope.var(name).get_tensor()
tensor.set_dims(shape) tensor.set_dims(shape)
tensor.set(np_data, core.CPUPlace()) tensor.set(np_data, core.CPUPlace())
return tensor return tensor
...@@ -72,8 +72,8 @@ class DynamicRecurrentOpTest(unittest.TestCase): ...@@ -72,8 +72,8 @@ class DynamicRecurrentOpTest(unittest.TestCase):
create_tensor(self.scope, "U", [self.input_dim, self.input_dim], U) create_tensor(self.scope, "U", [self.input_dim, self.input_dim], U)
create_tensor(self.scope, "h_boot", [self.num_sents, self.input_dim], create_tensor(self.scope, "h_boot", [self.num_sents, self.input_dim],
h_boot) h_boot)
self.scope.new_var("step_scopes") self.scope.var("step_scopes")
self.scope.new_var("h@mem") self.scope.var("h@mem")
def create_rnn_op(self): def create_rnn_op(self):
# create RNNOp # create RNNOp
......
...@@ -122,7 +122,7 @@ class TestBlockDesc(unittest.TestCase): ...@@ -122,7 +122,7 @@ class TestBlockDesc(unittest.TestCase):
var3 = block.var("var3") var3 = block.var("var3")
all_vars = block.all_vars() all_vars = block.all_vars()
self.assertEqual(set(all_vars), set([var1, var2, var3])) self.assertEqual(set(all_vars), set([var1, var2, var3]))
var2_re = block.var("var2") var2_re = block.find_var("var2")
self.assertEqual(var2_re, var2) self.assertEqual(var2_re, var2)
def test_add_op(self): def test_add_op(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册