import unittest
import paddle.v2.fluid as fluid
import numpy


class BaseParallelForTest(unittest.TestCase):
    def run_test(self, callback, feed, fetch):
        """
        Run the unittest for parallel.for
        Args:
            callback(callable): A callable function returns a generator. There 
                are two yields in the generator function. The first yield 
                returns the data layers, and the second yield returns the loss. 
                The modified data variables will be sent back during the first 
                yield.
            
            feed(dict): The executor feeding dictionary.
            fetch(list|basestr): The fetch name lists. 

        Returns:
            None
            
        Raises:
            AssertionError when the computation of cpu, parallel.for in cpu, 
                gpu, parallel.for in gpu are different.

        """
        cpu = fluid.CPUPlace()
        result_cpu = self._run_test_impl_(
            callback=callback,
            feed=feed,
            fetch=fetch,
            place=cpu,
            use_parallel=False)
        result_cpu_parallel = self._run_test_impl_(
            callback=callback,
            feed=feed,
            fetch=fetch,
            place=cpu,
            use_parallel=True)
        if fluid.core.is_compile_gpu():
            gpu = fluid.CUDAPlace(0)
            result_gpu = self._run_test_impl_(
                callback=callback,
                feed=feed,
                fetch=fetch,
                place=gpu,
                use_parallel=False)
            result_gpu_parallel = self._run_test_impl_(
                callback=callback,
                feed=feed,
                fetch=fetch,
                place=gpu,
                use_parallel=True)
            self._assert_same_(fetch, result_cpu, result_cpu_parallel,
                               result_gpu, result_gpu_parallel)
        else:
            self._assert_same_(fetch, result_cpu, result_cpu_parallel)

    def _run_test_impl_(self, callback, feed, fetch, place, use_parallel=False):
        """
        Run a single test, returns the fetch values
        Args:
            place(Place): the computation place. 
            use_parallel(bool): Whether use parallel.for or not. 

        Returns:
            Fetched numpy arrays.

        """
        if isinstance(fetch, basestring):
            fetch = [fetch]
        main = fluid.Program()
        startup = fluid.Program()
        # Fix the random seeds so results are comparable across places
        main.random_seed = 10
        startup.random_seed = 10

        with fluid.program_guard(main, startup):
            generator = callback()
            # Automatically insert parallel do if use_parallel = True
            if use_parallel:
                places = fluid.layers.get_places()
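                # ParallelDo runs its block once per place, splitting the
                # input batch across the devices returned by get_places().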
                pd = fluid.layers.ParallelDo(places)
                data = next(generator)

                if isinstance(data, fluid.Variable):
                    data = [data]

                with pd.do():
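                    # read_input makes each data variable visible inside the
                    # parallel block, one slice of the batch per place.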
                    ins = map(pd.read_input, data)
                    if len(ins) == 1:
                        ins = ins[0]
                    loss = generator.send(ins)  # patch input
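                    # write_output collects this place's loss; calling pd()
                    # afterwards merges the per-place results.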
                    pd.write_output(loss)

                loss = pd()
            else:
                data = next(generator)
                loss = generator.send(data)
            self.assertIsNotNone(loss)
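            # Append a mean op and the backward pass so that gradients
            # (e.g. 'fc1.w@GRAD') are available in the fetch list.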
            avg_loss = fluid.layers.mean(x=loss)
            fluid.backward.append_backward(loss=avg_loss)

        exe = fluid.Executor(place)
        exe.run(startup)
        return exe.run(main, feed=feed, fetch_list=fetch)

    def _assert_same_(self, fetch, *args):
        """
        Assert the return values of `run_test` are same.
        Args:
            fetch: Fetch list. Used for print error message
            *args: The fetch result lists of each situations.

        Returns:
            None
            
        Raises:
            AssertionError

        """

        def _impl_(a, b, fetch_id, item_id):
            item_str = ['CPU', 'ParallelCPU', 'GPU', 'ParallelGPU']
            flag = numpy.allclose(a, b, rtol=0.1)
            self.assertTrue(flag, "The fetched values of {0} differ on {1}".
                            format(fetch[fetch_id], item_str[item_id]))

        for i, items in enumerate(zip(*args)):
            self.assertGreater(len(items), 0)
            for j in range(1, len(items)):
                _impl_(items[0], items[j], fetch_id=i, item_id=j)


class ParallelOpTest(BaseParallelForTest):
    def test_simple_fc(self):
        def __network__():
            x = fluid.layers.data(shape=[784], dtype='float32', name='img')
            # FIXME: This is a bug in parallel.do
            x.stop_gradient = False
            x = yield x
            hidden = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
            loss = fluid.layers.mean(x=hidden)
            yield loss

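        # Feed a random batch and check that the gradient of fc1.w matches
        # across CPU/GPU, with and without parallel.do.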
        self.run_test(
            callback=__network__,
            feed={
                'img': numpy.random.random(size=(128, 784)).astype('float32')
            },
            fetch='fc1.w@GRAD')


if __name__ == '__main__':
    unittest.main()