From d09fd1f6f0b87a3c188859d951d9c66be73c8190 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 10 Jun 2018 18:56:40 +0800 Subject: [PATCH] test seresnext --- .../unittests/test_parallel_executor_seresnext.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py index a3fa140cb..ef6f3b99b 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py @@ -130,7 +130,9 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False): class TestResnet(TestParallelExecutorBase): - def check_resnet_convergence(self, balance_parameter_opt_between_cards): + def check_resnet_convergence(self, + balance_parameter_opt_between_cards, + use_cuda=True): import functools batch_size = 2 self.check_network_convergence( @@ -138,14 +140,17 @@ class TestResnet(TestParallelExecutorBase): SE_ResNeXt50Small, batch_size=batch_size), iter=20, batch_size=batch_size, + use_cuda=use_cuda, balance_parameter_opt_between_cards=balance_parameter_opt_between_cards ) def test_resnet(self): - self.check_resnet_convergence(False) + self.check_resnet_convergence(False, use_cuda=True) + # self.check_resnet_convergence(False,use_cuda=False) def test_resnet_with_new_strategy(self): - self.check_resnet_convergence(True) + self.check_resnet_convergence(True, use_cuda=True) + self.check_resnet_convergence(True, use_cuda=False) if __name__ == '__main__': -- GitLab