diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 6d6fe245d8a0d9b3a29f11171e7d945e09a4133c..c28c0809d853b6aaca2868391bfad682ded94623 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -18,6 +18,7 @@ if(NOT WITH_DISTRIBUTE) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist) LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec) LIST(REMOVE_ITEM TEST_OPS test_dist_ctr) + LIST(REMOVE_ITEM TEST_OPS test_dist_ctr_with_l2_decay) LIST(REMOVE_ITEM TEST_OPS test_dist_simnet_bow) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist_batch_merge) LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification) @@ -100,7 +101,7 @@ if(WITH_DISTRIBUTE) # FIXME(typhoonzero): add these tests back # py_test_modules(test_dist_transformer MODULES test_dist_transformer) # set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) - set_tests_properties(test_dist_ctr test_dist_mnist test_dist_mnist_batch_merge test_dist_save_load test_dist_se_resnext test_dist_simnet_bow test_dist_text_classification test_dist_train test_dist_word2vec PROPERTIES RUN_SERIAL TRUE) + set_tests_properties(test_dist_ctr test_dist_ctr_with_l2_decay test_dist_mnist test_dist_mnist_batch_merge test_dist_save_load test_dist_se_resnext test_dist_simnet_bow test_dist_text_classification test_dist_train test_dist_word2vec PROPERTIES RUN_SERIAL TRUE) endif(NOT APPLE) py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) endif() diff --git a/python/paddle/fluid/tests/unittests/dist_ctr.py b/python/paddle/fluid/tests/unittests/dist_ctr.py index dd97853a4c447e12b140d0f953eb407435706e66..e696ef23bdd41beeb8d52631fbfa1ee2ea30e068 100644 --- a/python/paddle/fluid/tests/unittests/dist_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_ctr.py @@ -30,11 +30,7 @@ fluid.default_main_program().random_seed = 1 class TestDistCTR2x2(TestDistRunnerBase): - def config(self): - self.use_l2_decay = False - def get_model(self, batch_size=2): - self.config() dnn_input_dim, lr_input_dim = dist_ctr_reader.load_data_meta() """ network definition """ @@ -103,7 +99,8 @@ class TestDistCTR2x2(TestDistRunnerBase): inference_program = paddle.fluid.default_main_program().clone() regularization = None - if self.use_l2_decay: + use_l2_decay = bool(os.getenv(['USE_L2_DECAY'], 0)) + if use_l2_decay: regularization = fluid.regularizer.L2DecayRegularizer( regularization_coeff=1e-3) diff --git a/python/paddle/fluid/tests/unittests/test_dist_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_ctr.py index f6b0971c5c94f1a43e4ff7beb6b8a52856337909..390393e04f8a1ff7b994da66cf1fa104ccb61793 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_ctr.py @@ -18,7 +18,6 @@ import unittest from test_dist_base import TestDistBase -# FIXME(tangwei): sum op can not handle when inputs is empty. class TestDistCTR2x2(TestDistBase): def _setup_config(self): self._sync_mode = True @@ -28,15 +27,5 @@ class TestDistCTR2x2(TestDistBase): self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) -class TestDistCTR2x2WithL2Decay(TestDistBase): - def _setup_config(self): - self._sync_mode = True - self._enforce_place = "CPU" - - def test_dist_ctr(self): - self.check_with_place( - "dist_ctr_with_l2_decay.py", delta=1e-7, check_error_log=False) - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dist_ctr_with_l2_decay.py b/python/paddle/fluid/tests/unittests/test_dist_ctr_with_l2_decay.py similarity index 60% rename from python/paddle/fluid/tests/unittests/dist_ctr_with_l2_decay.py rename to python/paddle/fluid/tests/unittests/test_dist_ctr_with_l2_decay.py index a7fbfd644d232de65c1cfaeeb8c6a89d51e75897..558aee36536805998a630841ec163605dedfbe14 100644 --- a/python/paddle/fluid/tests/unittests/dist_ctr_with_l2_decay.py +++ b/python/paddle/fluid/tests/unittests/test_dist_ctr_with_l2_decay.py @@ -11,17 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import print_function -import dist_ctr -from test_dist_base import runtime_main +import os +import unittest +from test_dist_base import TestDistBase + +class TestDistCTR2x2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._enforce_place = "CPU" -class TestDistCTRWithL2Decay(dist_ctr.TestDistCTR2x2): - def config(self): - self.use_l2_decay = True + def test_dist_ctr(self): + need_envs = {"USE_L2_DECAY": "1"} + self.check_with_place( + "dist_ctr.py", + delta=1e-7, + check_error_log=False, + need_envs=need_envs) if __name__ == "__main__": - runtime_main(TestDistCTRWithL2Decay) + unittest.main()