提交 a980b83a 编写于 作者: Y Yu Yang

Fix RNN unittest bugs.

* The DataProvider should be INCREF every time.
上级 fefb3c13
...@@ -252,19 +252,9 @@ private: ...@@ -252,19 +252,9 @@ private:
// only for instance will make python reference-count error. // only for instance will make python reference-count error.
// //
// So here, we increase reference count manually. // So here, we increase reference count manually.
if (gModuleClsPtrs_.find((uintptr_t)module.get()) != Py_XINCREF(module.get());
gModuleClsPtrs_.end()) { Py_XINCREF(moduleDict.get());
// Multi instance use same module Py_XINCREF(cls.get());
Py_XINCREF(module.get());
Py_XINCREF(moduleDict.get());
} else {
gModuleClsPtrs_.insert((uintptr_t)module.get());
}
if (gModuleClsPtrs_.find((uintptr_t)cls.get()) != gModuleClsPtrs_.end()) {
Py_XINCREF(cls.get());
} else {
gModuleClsPtrs_.insert((uintptr_t)cls.get());
}
PyObjectPtr fileListInPy = loadPyFileLists(fileListName); PyObjectPtr fileListInPy = loadPyFileLists(fileListName);
PyDict_SetItemString(kwargs.get(), "file_list", fileListInPy.get()); PyDict_SetItemString(kwargs.get(), "file_list", fileListInPy.get());
...@@ -471,7 +461,6 @@ private: ...@@ -471,7 +461,6 @@ private:
std::vector<std::string> fileLists_; std::vector<std::string> fileLists_;
std::vector<SlotHeader> headers_; std::vector<SlotHeader> headers_;
static PyObjectPtr zeroTuple_; static PyObjectPtr zeroTuple_;
static std::unordered_set<uintptr_t> gModuleClsPtrs_;
class PositionRandom { class PositionRandom {
public: public:
...@@ -671,7 +660,6 @@ public: ...@@ -671,7 +660,6 @@ public:
} }
}; };
std::unordered_set<uintptr_t> PyDataProvider2::gModuleClsPtrs_;
PyObjectPtr PyDataProvider2::zeroTuple_(PyTuple_New(0)); PyObjectPtr PyDataProvider2::zeroTuple_(PyTuple_New(0));
REGISTER_DATA_PROVIDER_EX(py2, PyDataProvider2); REGISTER_DATA_PROVIDER_EX(py2, PyDataProvider2);
......
...@@ -127,7 +127,7 @@ TEST(RecurrentGradientMachine, HasSubSequence) { ...@@ -127,7 +127,7 @@ TEST(RecurrentGradientMachine, HasSubSequence) {
} }
} }
TEST(RecurrentGradientMachine, DISABLED_rnn) { TEST(RecurrentGradientMachine, rnn) {
for (bool useGpu : {false, true}) { for (bool useGpu : {false, true}) {
test("gserver/tests/sequence_rnn.conf", test("gserver/tests/sequence_rnn.conf",
"gserver/tests/sequence_nest_rnn.conf", "gserver/tests/sequence_nest_rnn.conf",
...@@ -136,7 +136,7 @@ TEST(RecurrentGradientMachine, DISABLED_rnn) { ...@@ -136,7 +136,7 @@ TEST(RecurrentGradientMachine, DISABLED_rnn) {
} }
} }
TEST(RecurrentGradientMachine, DISABLED_rnn_multi_input) { TEST(RecurrentGradientMachine, rnn_multi_input) {
for (bool useGpu : {false, true}) { for (bool useGpu : {false, true}) {
test("gserver/tests/sequence_rnn_multi_input.conf", test("gserver/tests/sequence_rnn_multi_input.conf",
"gserver/tests/sequence_nest_rnn_multi_input.conf", "gserver/tests/sequence_nest_rnn_multi_input.conf",
...@@ -145,7 +145,7 @@ TEST(RecurrentGradientMachine, DISABLED_rnn_multi_input) { ...@@ -145,7 +145,7 @@ TEST(RecurrentGradientMachine, DISABLED_rnn_multi_input) {
} }
} }
TEST(RecurrentGradientMachine, DISABLED_rnn_multi_unequalength_input) { TEST(RecurrentGradientMachine, rnn_multi_unequalength_input) {
for (bool useGpu : {false, true}) { for (bool useGpu : {false, true}) {
test("gserver/tests/sequence_rnn_multi_unequalength_inputs.py", test("gserver/tests/sequence_rnn_multi_unequalength_inputs.py",
"gserver/tests/sequence_nest_rnn_multi_unequalength_inputs.py", "gserver/tests/sequence_nest_rnn_multi_unequalength_inputs.py",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册