提交 754f0c68 编写于 作者: Y Yu Yang

Fix unittest

上级 b80590d7
......@@ -57,8 +57,8 @@ class Scope {
return var;
} else {
auto ptr = new Variable();
vars_[name] = std::unique_ptr<Variable>(ptr);
var_names_[ptr] = name;
name_to_var_[name] = std::unique_ptr<Variable>(ptr);
var_to_name_[ptr] = name;
return GetVariable(name);
}
}
......@@ -70,8 +70,8 @@ class Scope {
* from it's parent scope. Return nullptr if not found.
*/
Variable* GetVariable(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) {
auto it = name_to_var_.find(name);
if (it != name_to_var_.end()) {
return it->second.get();
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
......@@ -86,21 +86,21 @@ class Scope {
* Find if there is a Variable in this scope and it's parent scope
*/
bool HasVariable(const std::string& name) const {
return (vars_.find(name) != vars_.end() ||
return (name_to_var_.find(name) != name_to_var_.end() ||
(parent_ && parent_->HasVariable(name)));
}
std::string GetVariableName(Variable* const var) const {
try {
return var_names_.at(var);
return var_to_name_.at(var);
} catch (...) {
return "";
}
}
private:
std::unordered_map<Variable*, std::string> var_names_;
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
std::unordered_map<Variable*, std::string> var_to_name_;
std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_;
std::shared_ptr<Scope> parent_{nullptr};
};
......
......@@ -42,6 +42,9 @@ TEST(Scope, Create) {
EXPECT_EQ(var4, var2);
EXPECT_EQ("a", scope->GetVariableName(var4));
Scope scope2;
auto var = scope2.CreateVariable("tmp");
EXPECT_EQ("", scope->GetVariableName(var));
}
TEST(Scope, Parent) {
......
......@@ -15,14 +15,6 @@ limitations under the License. */
#include <Python.h>
#include <fstream>
#include <vector>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
......@@ -160,7 +152,7 @@ All parameter, weight, gradient are variables in Paddle.
net.def_static("create",
[]() -> std::shared_ptr<pd::PlainNet> {
auto retv = std::make_shared<pd::PlainNet>();
retv->type_ = "naive_net";
retv->type_ = "plain_net";
return retv;
})
.def("add_op", &pd::PlainNet::AddOp)
......
......@@ -11,7 +11,7 @@ class TestNet(unittest.TestCase):
net.complete_add_op()
self.assertTrue(isinstance(fc_out, core.Variable))
self.assertEqual(
'''Op(naive_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
'''Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
......@@ -23,7 +23,7 @@ class TestNet(unittest.TestCase):
self.assertTrue(isinstance(tmp, core.Variable))
net2.complete_add_op()
self.assertEqual(
'''Op(naive_net), inputs:(X, Y), outputs:(add_two@OUT@2).
'''Op(plain_net), inputs:(X, Y), outputs:(add_two@OUT@2).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
''', str(net2))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册