提交 ab86fb11 编写于 作者: J JiayiFeng

complete parallel accuracy test

上级 415460b5
...@@ -130,7 +130,8 @@ class ParallelExecutor(object): ...@@ -130,7 +130,8 @@ class ParallelExecutor(object):
or numpy array. or numpy array.
:return: fetched value list. :return: fetched value list.
""" """
feed = feed_dict if feed == {}:
feed = feed_dict
if not isinstance(feed, dict): if not isinstance(feed, dict):
raise TypeError("feed should be a dict") raise TypeError("feed should be a dict")
......
...@@ -200,17 +200,29 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -200,17 +200,29 @@ class TestParallelExecutorBase(unittest.TestCase):
def check_network_convergence(self, def check_network_convergence(self,
method, method,
memory_opt=True, memory_opt=True,
iter=10, iter=50,
batch_size=None, batch_size=None,
allow_op_delay=False, allow_op_delay=False,
feed_dict={}, feed_dict={},
seed=None, seed=None,
use_parallel_executor=True): use_parallel_executor=True):
def run_executor(exe, feed, fetch_list, program=None):
if isinstance(exe, fluid.ParallelExecutor):
res = exe.run(fetch_list=fetch_list, feed=feed)
elif isinstance(exe, fluid.Executor):
if program is None:
program = fluid.default_main_program()
res = exe.run(program=program, feed=feed, fetch_list=fetch_list)
else:
raise ValueError('Unkown type exe')
return res
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
if seed is not None: if seed is not None:
startup.random_seed = seed startup.random_seed = seed
main.random_seed = seed
loss = method(use_feed=len(feed_dict) > 0) loss = method(use_feed=len(feed_dict) > 0)
adam = fluid.optimizer.Adam() adam = fluid.optimizer.Adam()
adam.minimize(loss) adam.minimize(loss)
...@@ -229,13 +241,15 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -229,13 +241,15 @@ class TestParallelExecutorBase(unittest.TestCase):
if batch_size is not None: if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count() batch_size *= fluid.core.get_cuda_device_count()
begin = time.time() begin = time.time()
first_loss, = exe.run([loss.name], feed=feed_dict) first_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
first_loss = numpy.array(first_loss) first_loss = numpy.array(first_loss)
for i in xrange(iter): for i in xrange(iter):
exe.run([], feed=feed_dict) run_executor(exe=exe, feed=feed_dict, fetch_list=[])
last_loss, = exe.run([loss.name], feed=feed_dict) last_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
end = time.time() end = time.time()
if batch_size is not None: if batch_size is not None:
...@@ -277,14 +291,25 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -277,14 +291,25 @@ class TestMNIST(TestParallelExecutorBase):
"label": label}) "label": label})
def test_simple_fc_parallel_accuracy(self): def test_simple_fc_parallel_accuracy(self):
#single_first_loss, single_last_loss = self.check_network_convergence( img = numpy.zeros(shape=[32, 784], dtype='float32')
# simple_fc_net, seed=0, use_parallel_executor=False) label = numpy.ones(shape=[32, 1], dtype='int64')
#parallel_first_loss, parallel_last_loss = self.check_network_convergence( single_first_loss, single_last_loss = self.check_network_convergence(
# simple_fc_net, seed=0, use_parallel_executor=True) method=simple_fc_net,
print('single_first_loss=', single_first_loss) seed=1000,
print('single_last_loss=', single_last_loss) feed_dict={"image": img,
print('parallel_first_loss=', parallel_first_loss) "label": label},
print('parallel_last_loss=', parallel_last_loss) use_parallel_executor=False)
parallel_first_loss, parallel_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1000,
feed_dict={"image": img,
"label": label},
use_parallel_executor=True)
for p_f in parallel_first_loss:
self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6)
for p_l in parallel_last_loss:
self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
def test_batchnorm_fc(self): def test_batchnorm_fc(self):
self.check_network_convergence(fc_with_batchnorm) self.check_network_convergence(fc_with_batchnorm)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册