提交 ab86fb11 编写于 作者: J JiayiFeng

complete parallel accuracy test

上级 415460b5
......@@ -130,7 +130,8 @@ class ParallelExecutor(object):
or numpy array.
:return: fetched value list.
"""
feed = feed_dict
if feed == {}:
feed = feed_dict
if not isinstance(feed, dict):
raise TypeError("feed should be a dict")
......
......@@ -200,17 +200,29 @@ class TestParallelExecutorBase(unittest.TestCase):
def check_network_convergence(self,
method,
memory_opt=True,
iter=10,
iter=50,
batch_size=None,
allow_op_delay=False,
feed_dict={},
seed=None,
use_parallel_executor=True):
def run_executor(exe, feed, fetch_list, program=None):
if isinstance(exe, fluid.ParallelExecutor):
res = exe.run(fetch_list=fetch_list, feed=feed)
elif isinstance(exe, fluid.Executor):
if program is None:
program = fluid.default_main_program()
res = exe.run(program=program, feed=feed, fetch_list=fetch_list)
else:
raise ValueError('Unkown type exe')
return res
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
if seed is not None:
startup.random_seed = seed
main.random_seed = seed
loss = method(use_feed=len(feed_dict) > 0)
adam = fluid.optimizer.Adam()
adam.minimize(loss)
......@@ -229,13 +241,15 @@ class TestParallelExecutorBase(unittest.TestCase):
if batch_size is not None:
batch_size *= fluid.core.get_cuda_device_count()
begin = time.time()
first_loss, = exe.run([loss.name], feed=feed_dict)
first_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
first_loss = numpy.array(first_loss)
for i in xrange(iter):
exe.run([], feed=feed_dict)
run_executor(exe=exe, feed=feed_dict, fetch_list=[])
last_loss, = exe.run([loss.name], feed=feed_dict)
last_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
end = time.time()
if batch_size is not None:
......@@ -277,14 +291,25 @@ class TestMNIST(TestParallelExecutorBase):
"label": label})
def test_simple_fc_parallel_accuracy(self):
#single_first_loss, single_last_loss = self.check_network_convergence(
# simple_fc_net, seed=0, use_parallel_executor=False)
#parallel_first_loss, parallel_last_loss = self.check_network_convergence(
# simple_fc_net, seed=0, use_parallel_executor=True)
print('single_first_loss=', single_first_loss)
print('single_last_loss=', single_last_loss)
print('parallel_first_loss=', parallel_first_loss)
print('parallel_last_loss=', parallel_last_loss)
img = numpy.zeros(shape=[32, 784], dtype='float32')
label = numpy.ones(shape=[32, 1], dtype='int64')
single_first_loss, single_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1000,
feed_dict={"image": img,
"label": label},
use_parallel_executor=False)
parallel_first_loss, parallel_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1000,
feed_dict={"image": img,
"label": label},
use_parallel_executor=True)
for p_f in parallel_first_loss:
self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6)
for p_l in parallel_last_loss:
self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
def test_batchnorm_fc(self):
self.check_network_convergence(fc_with_batchnorm)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册