提交 d09fd1f6 编写于 作者: C chengduoZH

test seresnext

上级 27073c28
......@@ -130,7 +130,9 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False):
class TestResnet(TestParallelExecutorBase):
def check_resnet_convergence(self, balance_parameter_opt_between_cards):
def check_resnet_convergence(self,
balance_parameter_opt_between_cards,
use_cuda=True):
import functools
batch_size = 2
self.check_network_convergence(
......@@ -138,14 +140,17 @@ class TestResnet(TestParallelExecutorBase):
SE_ResNeXt50Small, batch_size=batch_size),
iter=20,
batch_size=batch_size,
use_cuda=use_cuda,
balance_parameter_opt_between_cards=balance_parameter_opt_between_cards
)
def test_resnet(self):
self.check_resnet_convergence(False)
self.check_resnet_convergence(False, use_cuda=True)
# self.check_resnet_convergence(False,use_cuda=False)
def test_resnet_with_new_strategy(self):
self.check_resnet_convergence(True)
self.check_resnet_convergence(True, use_cuda=True)
self.check_resnet_convergence(True, use_cuda=False)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册