From a980b83a0cb974f0622051bea957f8305831b4bb Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 18 Dec 2016 13:53:11 +0800 Subject: [PATCH] Fix RNN unittest bugs. * The DataProvider should be INCREF every time. --- .../gserver/dataproviders/PyDataProvider2.cpp | 18 +++--------------- .../tests/test_RecurrentGradientMachine.cpp | 6 +++--- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/paddle/gserver/dataproviders/PyDataProvider2.cpp b/paddle/gserver/dataproviders/PyDataProvider2.cpp index 460efc5adc..c26e242534 100644 --- a/paddle/gserver/dataproviders/PyDataProvider2.cpp +++ b/paddle/gserver/dataproviders/PyDataProvider2.cpp @@ -252,19 +252,9 @@ private: // only for instance will make python reference-count error. // // So here, we increase reference count manually. - if (gModuleClsPtrs_.find((uintptr_t)module.get()) != - gModuleClsPtrs_.end()) { - // Multi instance use same module - Py_XINCREF(module.get()); - Py_XINCREF(moduleDict.get()); - } else { - gModuleClsPtrs_.insert((uintptr_t)module.get()); - } - if (gModuleClsPtrs_.find((uintptr_t)cls.get()) != gModuleClsPtrs_.end()) { - Py_XINCREF(cls.get()); - } else { - gModuleClsPtrs_.insert((uintptr_t)cls.get()); - } + Py_XINCREF(module.get()); + Py_XINCREF(moduleDict.get()); + Py_XINCREF(cls.get()); PyObjectPtr fileListInPy = loadPyFileLists(fileListName); PyDict_SetItemString(kwargs.get(), "file_list", fileListInPy.get()); @@ -471,7 +461,6 @@ private: std::vector fileLists_; std::vector headers_; static PyObjectPtr zeroTuple_; - static std::unordered_set gModuleClsPtrs_; class PositionRandom { public: @@ -671,7 +660,6 @@ public: } }; -std::unordered_set PyDataProvider2::gModuleClsPtrs_; PyObjectPtr PyDataProvider2::zeroTuple_(PyTuple_New(0)); REGISTER_DATA_PROVIDER_EX(py2, PyDataProvider2); diff --git a/paddle/gserver/tests/test_RecurrentGradientMachine.cpp b/paddle/gserver/tests/test_RecurrentGradientMachine.cpp index b47279b77a..e19cf35cd5 100644 --- a/paddle/gserver/tests/test_RecurrentGradientMachine.cpp +++ b/paddle/gserver/tests/test_RecurrentGradientMachine.cpp @@ -127,7 +127,7 @@ TEST(RecurrentGradientMachine, HasSubSequence) { } } -TEST(RecurrentGradientMachine, DISABLED_rnn) { +TEST(RecurrentGradientMachine, rnn) { for (bool useGpu : {false, true}) { test("gserver/tests/sequence_rnn.conf", "gserver/tests/sequence_nest_rnn.conf", @@ -136,7 +136,7 @@ TEST(RecurrentGradientMachine, DISABLED_rnn) { } } -TEST(RecurrentGradientMachine, DISABLED_rnn_multi_input) { +TEST(RecurrentGradientMachine, rnn_multi_input) { for (bool useGpu : {false, true}) { test("gserver/tests/sequence_rnn_multi_input.conf", "gserver/tests/sequence_nest_rnn_multi_input.conf", @@ -145,7 +145,7 @@ TEST(RecurrentGradientMachine, DISABLED_rnn_multi_input) { } } -TEST(RecurrentGradientMachine, DISABLED_rnn_multi_unequalength_input) { +TEST(RecurrentGradientMachine, rnn_multi_unequalength_input) { for (bool useGpu : {false, true}) { test("gserver/tests/sequence_rnn_multi_unequalength_inputs.py", "gserver/tests/sequence_nest_rnn_multi_unequalength_inputs.py", -- GitLab