Commit 754f0c68 authored by Yu Yang

Fix unittest

Parent b80590d7
@@ -57,8 +57,8 @@ class Scope {
       return var;
     } else {
       auto ptr = new Variable();
-      vars_[name] = std::unique_ptr<Variable>(ptr);
-      var_names_[ptr] = name;
+      name_to_var_[name] = std::unique_ptr<Variable>(ptr);
+      var_to_name_[ptr] = name;
       return GetVariable(name);
     }
   }
@@ -70,8 +70,8 @@
    * from its parent scope. Return nullptr if not found.
    */
   Variable* GetVariable(const std::string& name) const {
-    auto it = vars_.find(name);
-    if (it != vars_.end()) {
+    auto it = name_to_var_.find(name);
+    if (it != name_to_var_.end()) {
       return it->second.get();
     } else if (parent_ != nullptr) {
       return parent_->GetVariable(name);
@@ -86,21 +86,21 @@
    * Find if there is a Variable in this scope or its parent scope.
    */
   bool HasVariable(const std::string& name) const {
-    return (vars_.find(name) != vars_.end() ||
+    return (name_to_var_.find(name) != name_to_var_.end() ||
            (parent_ && parent_->HasVariable(name)));
   }

   std::string GetVariableName(Variable* const var) const {
     try {
-      return var_names_.at(var);
+      return var_to_name_.at(var);
     } catch (...) {
       return "";
     }
   }

  private:
-  std::unordered_map<Variable*, std::string> var_names_;
-  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
+  std::unordered_map<Variable*, std::string> var_to_name_;
+  std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_;
   std::shared_ptr<Scope> parent_{nullptr};
 };
...
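The rename from vars_/var_names_ to name_to_var_/var_to_name_ makes the two-map design of Scope explicit. Below is a minimal, self-contained sketch of that design, not the actual Paddle implementation (the empty Variable and the main() driver are illustrative stand-ins): the forward map owns the Variable objects, and the reverse map lets GetVariableName answer pointer-to-name queries without a linear scan.

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Toy stand-in for paddle::framework::Variable (illustrative only).
class Variable {};

class Scope {
 public:
  // Create a variable, or return the existing one with the same name.
  Variable* CreateVariable(const std::string& name) {
    auto it = name_to_var_.find(name);
    if (it != name_to_var_.end()) return it->second.get();
    auto ptr = new Variable();
    name_to_var_[name] = std::unique_ptr<Variable>(ptr);  // owning map
    var_to_name_[ptr] = name;                             // reverse index
    return ptr;
  }

  // Reverse lookup; a pointer this scope never created simply misses,
  // which is the behavior the new unit test below pins down.
  std::string GetVariableName(Variable* const var) const {
    auto it = var_to_name_.find(var);
    return it == var_to_name_.end() ? "" : it->second;
  }

 private:
  std::unordered_map<Variable*, std::string> var_to_name_;
  std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_;
};

int main() {
  Scope s1, s2;
  Variable* a = s1.CreateVariable("a");
  assert(s1.GetVariableName(a) == "a");
  assert(s2.GetVariableName(a) == "");  // foreign pointer -> empty string
  return 0;
}

The sketch uses find() where the diff uses at() inside a try/catch; the two are behaviorally equivalent here, since the catch-all simply converts the missing-key exception into an empty string.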
@@ -42,6 +42,9 @@ TEST(Scope, Create) {
   EXPECT_EQ(var4, var2);
   EXPECT_EQ("a", scope->GetVariableName(var4));
+
+  Scope scope2;
+  auto var = scope2.CreateVariable("tmp");
+  EXPECT_EQ("", scope->GetVariableName(var));
 }

 TEST(Scope, Parent) {
...
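The new test creates a second, unrelated Scope to confirm that GetVariableName stays strictly local, whereas GetVariable (per the hunk further up) does walk the parent_ chain. A hedged sketch of that fallback follows; the constructor that wires up the parent is an assumption for this sketch, as the diff does not show how Paddle actually nests scopes.

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

class Variable {};

class Scope {
 public:
  Scope() = default;
  // Assumed convenience constructor for this sketch only.
  explicit Scope(std::shared_ptr<Scope> parent) : parent_(std::move(parent)) {}

  Variable* CreateVariable(const std::string& name) {
    auto ptr = new Variable();
    name_to_var_[name] = std::unique_ptr<Variable>(ptr);
    return ptr;
  }

  // A local miss falls through to the enclosing scope, as in the diff.
  Variable* GetVariable(const std::string& name) const {
    auto it = name_to_var_.find(name);
    if (it != name_to_var_.end()) {
      return it->second.get();
    } else if (parent_ != nullptr) {
      return parent_->GetVariable(name);
    }
    return nullptr;
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_;
  std::shared_ptr<Scope> parent_{nullptr};
};

int main() {
  auto outer = std::make_shared<Scope>();
  Variable* x = outer->CreateVariable("x");
  Scope inner(outer);
  assert(inner.GetVariable("x") == x);        // resolved via parent_
  assert(inner.GetVariable("y") == nullptr);  // not found anywhere
  return 0;
}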
@@ -15,14 +15,6 @@ limitations under the License. */
 #include <Python.h>
 #include <fstream>
 #include <vector>
-#include "paddle/framework/net.h"
-#include "paddle/framework/op_registry.h"
-#include "paddle/framework/operator.h"
-#include "paddle/framework/scope.h"
-#include "paddle/pybind/tensor_bind.h"
-#include "pybind11/numpy.h"
-#include "pybind11/pybind11.h"
-#include "pybind11/stl.h"
 #include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
@@ -160,7 +152,7 @@ All parameter, weight, gradient are variables in Paddle.
   net.def_static("create",
                  []() -> std::shared_ptr<pd::PlainNet> {
                    auto retv = std::make_shared<pd::PlainNet>();
-                   retv->type_ = "naive_net";
+                   retv->type_ = "plain_net";
                    return retv;
                  })
       .def("add_op", &pd::PlainNet::AddOp)
...
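The rename from "naive_net" to "plain_net" keeps the C++ type tag consistent with the PlainNet class name, which the Python tests below match against. For readers unfamiliar with the binding pattern, here is a minimal pybind11 sketch of a def_static factory like the one above; the toy PlainNet, the module name, and the exposed attribute are stand-ins, not Paddle's real API surface.

#include <memory>
#include <string>

#include "pybind11/pybind11.h"

namespace py = pybind11;

// Toy stand-in, just enough to show the binding pattern.
struct PlainNet {
  std::string type_;
};

PYBIND11_MODULE(example, m) {
  // Holding the class by shared_ptr lets the lambda hand ownership of the
  // freshly created net to Python.
  py::class_<PlainNet, std::shared_ptr<PlainNet>>(m, "Net")
      .def_static("create",
                  []() -> std::shared_ptr<PlainNet> {
                    auto retv = std::make_shared<PlainNet>();
                    retv->type_ = "plain_net";  // the tag the tests check
                    return retv;
                  })
      .def_readonly("type", &PlainNet::type_);
}

From Python, example.Net.create() then returns a new net whose lifetime is shared between the interpreter and any C++ code still holding the shared_ptr.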
@@ -11,7 +11,7 @@ class TestNet(unittest.TestCase):
         net.complete_add_op()
         self.assertTrue(isinstance(fc_out, core.Variable))
         self.assertEqual(
-            '''Op(naive_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
+            '''Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
 Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
 Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
 Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
@@ -23,7 +23,7 @@ class TestNet(unittest.TestCase):
         self.assertTrue(isinstance(tmp, core.Variable))
         net2.complete_add_op()
         self.assertEqual(
-            '''Op(naive_net), inputs:(X, Y), outputs:(add_two@OUT@2).
+            '''Op(plain_net), inputs:(X, Y), outputs:(add_two@OUT@2).
 Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
 ''', str(net2))
...