提交 d09fd1f6 编写于 作者: C chengduoZH

test seresnext

上级 27073c28
...@@ -130,7 +130,9 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False): ...@@ -130,7 +130,9 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False):
class TestResnet(TestParallelExecutorBase): class TestResnet(TestParallelExecutorBase):
def check_resnet_convergence(self, balance_parameter_opt_between_cards): def check_resnet_convergence(self,
balance_parameter_opt_between_cards,
use_cuda=True):
import functools import functools
batch_size = 2 batch_size = 2
self.check_network_convergence( self.check_network_convergence(
...@@ -138,14 +140,17 @@ class TestResnet(TestParallelExecutorBase): ...@@ -138,14 +140,17 @@ class TestResnet(TestParallelExecutorBase):
SE_ResNeXt50Small, batch_size=batch_size), SE_ResNeXt50Small, batch_size=batch_size),
iter=20, iter=20,
batch_size=batch_size, batch_size=batch_size,
use_cuda=use_cuda,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
) )
def test_resnet(self): def test_resnet(self):
self.check_resnet_convergence(False) self.check_resnet_convergence(False, use_cuda=True)
# self.check_resnet_convergence(False,use_cuda=False)
def test_resnet_with_new_strategy(self): def test_resnet_with_new_strategy(self):
self.check_resnet_convergence(True) self.check_resnet_convergence(True, use_cuda=True)
self.check_resnet_convergence(True, use_cuda=False)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册