提交 12aca860 编写于 作者: Y Yang Yu

Add comment

上级 2412f2f4
......@@ -4,15 +4,35 @@ import numpy
class BaseParallelForTest(unittest.TestCase):
def main(self, callback, feed, fetch):
def run_test(self, callback, feed, fetch):
"""
Run the unittest for parallel.for
Args:
callback(callable): A callable function returns a generator. There
are two yields in the generator function. The first yield
returns the data layers, and the second yield returns the loss.
The modified data variables will be sent back during the first
yield.
feed(dict): The executor feeding dictionary.
fetch(list|basestr): The fetch name lists.
Returns:
None
Raises:
AssertionError when the computation of cpu, parallel.for in cpu,
gpu, parallel.for in gpu are different.
"""
cpu = fluid.CPUPlace()
result_cpu = self._main_impl_(
result_cpu = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
place=cpu,
use_parallel=False)
result_cpu_parallel = self._main_impl_(
result_cpu_parallel = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
......@@ -20,13 +40,13 @@ class BaseParallelForTest(unittest.TestCase):
use_parallel=True)
if fluid.core.is_compile_gpu():
gpu = fluid.CUDAPlace(0)
result_gpu = self._main_impl_(
result_gpu = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
place=gpu,
use_parallel=False)
result_gpu_parallel = self._main_impl_(
result_gpu_parallel = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
......@@ -37,7 +57,17 @@ class BaseParallelForTest(unittest.TestCase):
else:
self._assert_same_(fetch, result_cpu, result_cpu_parallel)
def _main_impl_(self, callback, feed, fetch, place, use_parallel=False):
def _run_test_impl_(self, callback, feed, fetch, place, use_parallel=False):
"""
Run a single test, returns the fetch values
Args:
place(Place): the computation place.
use_parallel(bool): Whether use parallel.for or not.
Returns:
Fetched numpy arrays.
"""
if isinstance(fetch, basestring):
fetch = [fetch]
main = fluid.Program()
......@@ -77,6 +107,20 @@ class BaseParallelForTest(unittest.TestCase):
return exe.run(main, feed=feed, fetch_list=fetch)
def _assert_same_(self, fetch, *args):
"""
Assert the return values of `run_test` are same.
Args:
fetch: Fetch list. Used for print error message
*args: The fetch result lists of each situations.
Returns:
None
Raises:
AssertionError
"""
def _impl_(a, b, fetch_id, item_id):
item_str = ['CPU', 'ParallelCPU', 'GPU', 'ParallelGPU']
flag = numpy.allclose(a, b, rtol=0.1)
......@@ -100,7 +144,7 @@ class ParallelOpTest(BaseParallelForTest):
loss = fluid.layers.mean(x=hidden)
yield loss
self.main(
self.run_test(
callback=__network__,
feed={
'img': numpy.random.random(size=(128, 784)).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册