提交 5fc83267 编写于 作者: F fengjiayi

Add parallel accuracy test

上级 494c262a
...@@ -203,10 +203,14 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -203,10 +203,14 @@ class TestParallelExecutorBase(unittest.TestCase):
iter=10, iter=10,
batch_size=None, batch_size=None,
allow_op_delay=False, allow_op_delay=False,
feed_dict={}): feed_dict={},
random_seed=None,
use_parallel_executor=True):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
if seed is not None:
startup.random_seed(random_seed)
loss = method(use_feed=len(feed_dict) > 0) loss = method(use_feed=len(feed_dict) > 0)
adam = fluid.optimizer.Adam() adam = fluid.optimizer.Adam()
adam.minimize(loss) adam.minimize(loss)
...@@ -217,7 +221,11 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -217,7 +221,11 @@ class TestParallelExecutorBase(unittest.TestCase):
startup_exe = fluid.Executor(place) startup_exe = fluid.Executor(place)
startup_exe.run(startup) startup_exe.run(startup)
if use_parallel_executor:
exe = fluid.ParallelExecutor(True, loss_name=loss.name) exe = fluid.ParallelExecutor(True, loss_name=loss.name)
else:
exe = fluid.Executor(place=place)
if batch_size is not None: if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count() batch_size *= fluid.core.get_cuda_device_count()
begin = time.time() begin = time.time()
...@@ -238,6 +246,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -238,6 +246,7 @@ class TestParallelExecutorBase(unittest.TestCase):
print first_loss, last_loss print first_loss, last_loss
# self.assertGreater(first_loss[0], last_loss[0]) # self.assertGreater(first_loss[0], last_loss[0])
return first_loss, last_loss
class TestMNIST(TestParallelExecutorBase): class TestMNIST(TestParallelExecutorBase):
...@@ -267,6 +276,17 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -267,6 +276,17 @@ class TestMNIST(TestParallelExecutorBase):
simple_fc_net, feed_dict={"image": img, simple_fc_net, feed_dict={"image": img,
"label": label}) "label": label})
def test_simple_fc_parallel_accuracy(self):
    """Check that ParallelExecutor reproduces single-device training.

    Runs the same simple FC network twice from the same random seed —
    once on a plain Executor, once on ParallelExecutor — and asserts the
    first/last loss values agree within a small numeric tolerance.
    """
    single_first_loss, single_last_loss = self.check_network_convergence(
        simple_fc_net, random_seed=0, use_parallel_executor=False)
    parallel_first_loss, parallel_last_loss = self.check_network_convergence(
        simple_fc_net, random_seed=0, use_parallel_executor=True)
    # Keep the raw values in the test log to ease debugging on failure.
    print('single_first_loss=', single_first_loss)
    print('single_last_loss=', single_last_loss)
    print('parallel_first_loss=', parallel_first_loss)
    print('parallel_last_loss=', parallel_last_loss)
    # The parallel run reports one loss per device; every device's loss
    # should match the single-device result. A small delta is allowed for
    # non-deterministic floating-point reduction order in the parallel run.
    # NOTE(review): assumes the single-device losses are length-1 sequences
    # (one device) — confirm against check_network_convergence.
    for p_loss in parallel_first_loss:
        self.assertAlmostEqual(p_loss, single_first_loss[0], delta=1e-6)
    for p_loss in parallel_last_loss:
        self.assertAlmostEqual(p_loss, single_last_loss[0], delta=1e-6)
def test_batchnorm_fc(self): def test_batchnorm_fc(self):
self.check_network_convergence(fc_with_batchnorm) self.check_network_convergence(fc_with_batchnorm)
img = numpy.zeros(shape=[32, 784], dtype='float32') img = numpy.zeros(shape=[32, 784], dtype='float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册