未验证 提交 c083ee70 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #9950 from JiayiFeng/add_parallel_executor_tests

Add parallel executor tests
...@@ -16,6 +16,7 @@ import core
import multiprocessing
import framework
import executor
import warnings
import sys

__all__ = ['ParallelExecutor']
...@@ -62,8 +63,8 @@ class ParallelExecutor(object):
    main_program=test_program,
    share_vars_from=train_exe)

train_loss, = train_exe.run([loss.name], feed=feed_dict)
test_loss, = test_exe.run([loss.name], feed=feed_dict)
"""
self._places = []
...@@ -103,8 +104,8 @@ class ParallelExecutor(object):
self.persistable_vars = [
    v.name
    for v in filter(
        lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW,
        main.list_vars())
]
...@@ -163,7 +164,7 @@ class ParallelExecutor(object):
Returns: fetched result list.
"""

if feed is None and feed_dict is not None:
    feed = feed_dict
    print >> sys.stderr, "`feed_dict` is deprecated. Please use `feed=`"
...
...@@ -200,14 +200,29 @@ class TestParallelExecutorBase(unittest.TestCase):
def check_network_convergence(self,
                              method,
                              memory_opt=True,
                              iter=50,
                              batch_size=None,
                              allow_op_delay=False,
                              feed_dict=None,
                              seed=None,
                              use_parallel_executor=True):
def run_executor(exe, feed, fetch_list, program=None):
    """Run one step on either executor flavor and return the fetch results.

    Args:
        exe: a ``fluid.ParallelExecutor`` or ``fluid.Executor`` instance.
        feed: feed dict handed through to the executor unchanged.
        fetch_list: list of variable names to fetch.
        program: only consulted for ``fluid.Executor``; when ``None`` the
            default main program is used.

    Returns:
        The fetched result list produced by ``exe.run``.

    Raises:
        ValueError: if ``exe`` is neither executor type.
    """
    if isinstance(exe, fluid.ParallelExecutor):
        # ParallelExecutor picks its own program; it only needs fetch/feed.
        res = exe.run(fetch_list=fetch_list, feed=feed)
    elif isinstance(exe, fluid.Executor):
        if program is None:
            # Allow callers to omit the program for the common case.
            program = fluid.default_main_program()
        res = exe.run(program=program, feed=feed, fetch_list=fetch_list)
    else:
        # Fixed typo in the error message: "Unkown" -> "Unknown".
        raise ValueError('Unknown type exe')
    return res
main = fluid.Program()
startup = fluid.Program()
startup.random_seed = 1  # Fix random seed
with fluid.program_guard(main, startup):
    if seed is not None:
        startup.random_seed = seed
    loss = method(use_feed=feed_dict is not None)
    adam = fluid.optimizer.Adam()
    adam.minimize(loss)
...@@ -217,18 +232,24 @@ class TestParallelExecutorBase(unittest.TestCase):
startup_exe = fluid.Executor(place)
startup_exe.run(startup)

if use_parallel_executor:
    exe = fluid.ParallelExecutor(
        True, loss_name=loss.name, allow_op_delay=allow_op_delay)
else:
    exe = fluid.Executor(place=place)

if batch_size is not None:
    batch_size *= fluid.core.get_cuda_device_count()
begin = time.time()
first_loss, = run_executor(
    exe=exe, feed=feed_dict, fetch_list=[loss.name])
first_loss = numpy.array(first_loss)

for i in xrange(iter):
    run_executor(exe=exe, feed=feed_dict, fetch_list=[])

last_loss, = run_executor(
    exe=exe, feed=feed_dict, fetch_list=[loss.name])
end = time.time()

if batch_size is not None:
...@@ -239,6 +260,7 @@ class TestParallelExecutorBase(unittest.TestCase):
print first_loss, last_loss
# self.assertGreater(first_loss[0], last_loss[0])
return first_loss, last_loss
class TestMNIST(TestParallelExecutorBase):
...@@ -268,6 +290,27 @@ class TestMNIST(TestParallelExecutorBase):
    simple_fc_net, feed_dict={"image": img,
                              "label": label})
def test_simple_fc_parallel_accuracy(self):
    """Verify ParallelExecutor matches the single-device Executor.

    Runs the same fc network twice from the same random seed — once on a
    plain Executor, once on a ParallelExecutor — and checks that the first
    and last losses agree within float tolerance on every device replica.
    """
    img = numpy.zeros(shape=[32, 784], dtype='float32')
    label = numpy.ones(shape=[32, 1], dtype='int64')
    single_first_loss, single_last_loss = self.check_network_convergence(
        method=simple_fc_net,
        seed=1000,
        feed_dict={"image": img,
                   "label": label},
        use_parallel_executor=False)
    parallel_first_loss, parallel_last_loss = self.check_network_convergence(
        method=simple_fc_net,
        seed=1000,
        feed_dict={"image": img,
                   "label": label},
        use_parallel_executor=True)
    # ParallelExecutor returns one loss per device; each replica must match
    # the single-device baseline. Use assertAlmostEqual (assertAlmostEquals
    # is a deprecated alias).
    for p_f in parallel_first_loss:
        self.assertAlmostEqual(p_f, single_first_loss[0], delta=1e-6)
    for p_l in parallel_last_loss:
        self.assertAlmostEqual(p_l, single_last_loss[0], delta=1e-6)
def test_batchnorm_fc(self):
    self.check_network_convergence(fc_with_batchnorm)
    img = numpy.zeros(shape=[32, 784], dtype='float32')
...@@ -496,10 +539,10 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
    share_vars_from=train_exe)

for i in xrange(5):
    test_loss, = test_exe.run([loss.name], feed=feed_dict)
    test_loss = numpy.array(test_loss)

    train_loss, = train_exe.run([loss.name], feed=feed_dict)
    train_loss = numpy.array(train_loss)

self.assertTrue(
    numpy.allclose(
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册