# test_multiprocess_dataloader_iterable_dataset_static.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
import unittest

import numpy as np

import paddle
import paddle.fluid as fluid
from paddle.io import DataLoader, IterableDataset

# Training schedule shared by every test case below.
EPOCH_NUM = 2
BATCH_SIZE = 8
# Data layout: each sample is a flat float32 vector of IMAGE_SIZE values,
# labelled with one of CLASS_NUM classes.
IMAGE_SIZE = 32
CLASS_NUM = 10
# Samples per pass over a dataset, chosen as a whole number of batches
# (10 * 8 == 80).
SAMPLE_NUM = 10 * BATCH_SIZE


class RandomDataset(IterableDataset):
    """Iterable dataset yielding ``sample_num`` random ``(image, label)`` pairs.

    NumPy's global RNG is re-seeded with the sample index before each
    sample is drawn. Every pass over the dataset therefore produces the
    same stream. Note that re-seeding the global RNG is a side effect.

    Args:
        sample_num (int): number of samples yielded per pass.
        class_num (int): number of classes; labels are drawn from
            ``[0, class_num)``.
    """

    def __init__(self, sample_num, class_num):
        self.sample_num = sample_num
        self.class_num = class_num

    def __iter__(self):
        for i in range(self.sample_num):
            # Seed per sample so the data depends only on the index.
            np.random.seed(i)
            image = np.random.random([IMAGE_SIZE]).astype('float32')
            # ``randint``'s upper bound is exclusive; passing ``class_num``
            # (not ``class_num - 1``) lets the last class be sampled too.
            label = np.random.randint(0, self.class_num, (1,)).astype('int64')
            yield image, label


def simple_fc_net_static():
    """Build a small fully connected classifier as a pair of static programs.

    Network structure:

    * three ``tanh`` FC layers of sizes 10, 20 and 30;
    * a ``softmax`` FC layer with ``CLASS_NUM`` outputs;
    * loss is the mean cross-entropy, minimised with Adam.

    Every weight and bias uses a constant initializer, and both programs use
    ``random_seed = 1``, so repeated builds start from identical parameters.

    Returns:
        tuple: ``(startup_prog, main_prog, image, label, loss)``.
        ``image`` and ``label`` are the feed variables and ``loss`` is the
        mean-loss variable to fetch.
    """
    startup_prog = fluid.Program()
    main_prog = fluid.Program()
    startup_prog.random_seed = 1
    main_prog.random_seed = 1

    # Fresh name scope so repeated builds do not collide on parameter names.
    with fluid.unique_name.guard():
        with fluid.program_guard(main_prog, startup_prog):
            image = fluid.data(
                name='image', shape=[None, IMAGE_SIZE], dtype='float32'
            )
            label = fluid.data(name='label', shape=[None, 1], dtype='int64')
            hidden = image
            param_attr = fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=0.8)
            )
            bias_attr = fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=0.5)
            )
            for hidden_size in [10, 20, 30]:
                hidden = fluid.layers.fc(
                    hidden,
                    size=hidden_size,
                    act='tanh',
                    param_attr=param_attr,
                    bias_attr=bias_attr,
                )

            predict_label = fluid.layers.fc(
                hidden,
                size=CLASS_NUM,
                act='softmax',
                param_attr=param_attr,
                bias_attr=bias_attr,
            )
            loss = paddle.mean(
                fluid.layers.cross_entropy(input=predict_label, label=label)
            )

            optimizer = fluid.optimizer.Adam()
            optimizer.minimize(loss)
    return startup_prog, main_prog, image, label, loss


def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
    """Return the groups of places the tests should run on.

    Each element is a list of places used for one ``DataLoader`` run.

    * CPU: a single-place group, plus a two-CPU group when
      ``with_data_parallel`` is set.
    * GPU: a single-GPU group, plus a group of the first two GPUs when
      ``with_data_parallel`` is set and at least two GPUs are visible.

    Args:
        with_data_parallel (bool): also add multi-place groups.
        with_cpu (bool): include CPU place groups.
        with_gpu (bool): include CUDA place groups. This is ignored when
            Paddle is not compiled with CUDA.

    Returns:
        list[list]: the place groups. This may be empty, e.g. on a build
        without CUDA when ``with_cpu`` is left at its default ``False``.
    """
    places = []
    if with_cpu:
        places.append([fluid.CPUPlace()])
        if with_data_parallel:
            places.append([fluid.CPUPlace()] * 2)

    if with_gpu and fluid.core.is_compiled_with_cuda():
        gpu_places = fluid.cuda_places()[:2]
        assert len(gpu_places) > 0, "no gpu detected"
        if with_data_parallel and len(gpu_places) > 1:
            places.append(gpu_places)
        places.append([gpu_places[0]])
    return places


class TestStaticDataLoader(unittest.TestCase):
    """Static-graph ``DataLoader`` test over an un-batched ``IterableDataset``."""

    def run_main(self, num_workers, places, persistent_workers):
        """Train ``simple_fc_net_static`` for ``EPOCH_NUM`` epochs.

        Every fed batch is checked for the expected shapes and for its
        placement on the expected device.

        Args:
            num_workers (int): ``DataLoader`` worker count, forwarded as-is.
            places (list): places to feed. With more than one place, the
                program is compiled with ``with_data_parallel``.
            persistent_workers (bool): forwarded to ``DataLoader``.

        Returns:
            dict: with the keys below.

            * ``time``: wall-clock seconds for the training loop.
            * ``step``: number of steps in each epoch.
            * ``loss``: per-step mean loss, as a NumPy array.
        """
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            startup_prog, main_prog, image, label, loss = simple_fc_net_static()

            dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
            dataloader = DataLoader(
                dataset,
                feed_list=[image, label],
                places=places,
                num_workers=num_workers,
                batch_size=BATCH_SIZE,
                return_list=False,
                drop_last=True,
                persistent_workers=persistent_workers,
            )

            exe = fluid.Executor(place=places[0])
            exe.run(startup_prog)

            prog = fluid.CompiledProgram(main_prog)
            if len(places) > 1:
                prog = prog.with_data_parallel(
                    loss_name=loss.name, places=places
                )

            step_list = []
            loss_list = []
            start_t = time.time()
            for _ in range(EPOCH_NUM):
                step = 0
                for data in dataloader:
                    # With return_list=False each item is a feed dict,
                    # one per place.
                    self.assertEqual(len(data), len(places))
                    for place, item in zip(places, data):
                        # Distinct names so the graph vars `image`/`label`
                        # are not rebound to tensors.
                        image_t = item['image']
                        label_t = item['label']
                        self.assertEqual(
                            image_t.shape(), [BATCH_SIZE, IMAGE_SIZE]
                        )
                        self.assertEqual(label_t.shape(), [BATCH_SIZE, 1])
                        self.assertTrue(image_t._place()._equals(place))
                        self.assertTrue(label_t._place()._equals(place))
                    (loss_val,) = exe.run(
                        program=prog,
                        feed=data,
                        fetch_list=[loss],
                        use_program_cache=True,
                    )
                    loss_list.append(np.mean(loss_val))
                    step += 1
                step_list.append(step)

        end_t = time.time()
        ret = {
            "time": end_t - start_t,
            "step": step_list,
            "loss": np.array(loss_list),
        }
        print("time cost", ret['time'], 'step_list', ret['step'])
        return ret

    def test_main(self):
        """Compare a 0-worker run with a 2-worker run on every place group."""
        for places in prepare_places(True):
            for persistent_workers in [False, True]:
                results = []
                for num_workers in [0, 2]:
                    print(
                        self.__class__.__name__,
                        places,
                        num_workers,
                        persistent_workers,
                    )
                    sys.stdout.flush()
                    ret = self.run_main(
                        num_workers=num_workers,
                        places=places,
                        persistent_workers=persistent_workers,
                    )
                    results.append(ret)
                # The 2-worker run is expected to produce twice as many
                # steps. Presumably each worker iterates the whole
                # IterableDataset (no sharding); confirm against the
                # DataLoader docs.
                self.assertEqual(
                    results[0]['loss'].shape[0] * 2,
                    results[1]['loss'].shape[0],
                )


class RandomBatchedDataset(IterableDataset):
    """Iterable dataset yielding pre-stacked batches of random samples.

    Each item is ``(images, labels)``, where ``images`` stacks
    ``BATCH_SIZE`` float32 vectors and ``labels`` stacks ``BATCH_SIZE``
    int64 labels of shape ``(1,)``. NumPy's global RNG is re-seeded with the
    batch index, so every pass yields the same batches.

    Args:
        sample_num (int): total sample budget. Only
            ``sample_num // BATCH_SIZE`` full batches are yielded, and any
            remainder is dropped.
        class_num (int): number of classes; labels are drawn from
            ``[0, class_num)``.
    """

    def __init__(self, sample_num, class_num):
        # Despite its name, this attribute holds the number of *batches*.
        self.sample_num = sample_num // BATCH_SIZE
        self.class_num = class_num

    def __iter__(self):
        for i in range(self.sample_num):
            np.random.seed(i)
            images = []
            labels = []
            for _ in range(BATCH_SIZE):
                image = np.random.random([IMAGE_SIZE]).astype('float32')
                # ``randint``'s upper bound is exclusive; passing
                # ``class_num`` lets the last class be sampled too.
                label = np.random.randint(0, self.class_num, (1,)).astype(
                    'int64'
                )
                images.append(image)
                labels.append(label)
            yield np.stack(images, axis=0), np.stack(labels, axis=0)


class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
    """Same checks as the parent, but the dataset yields ready-made batches.

    ``RandomBatchedDataset`` already stacks ``BATCH_SIZE`` samples per item,
    so ``batch_size=None`` is passed to ``DataLoader``. ``test_main`` is
    inherited unchanged.
    """

    def run_main(self, num_workers, places, persistent_workers):
        """Train ``simple_fc_net_static`` for ``EPOCH_NUM`` epochs.

        Every fed batch is checked for the expected shapes and for its
        placement on the expected device.

        Args:
            num_workers (int): ``DataLoader`` worker count, forwarded as-is.
            places (list): places to feed. With more than one place, the
                program is compiled with ``with_data_parallel``.
            persistent_workers (bool): forwarded to ``DataLoader``.

        Returns:
            dict: with the keys below.

            * ``time``: wall-clock seconds for the training loop.
            * ``step``: number of steps in each epoch.
            * ``loss``: per-step mean loss, as a NumPy array.
        """
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            startup_prog, main_prog, image, label, loss = simple_fc_net_static()

            dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
            dataloader = DataLoader(
                dataset,
                feed_list=[image, label],
                places=places,
                num_workers=num_workers,
                batch_size=None,
                return_list=False,
                drop_last=True,
                persistent_workers=persistent_workers,
            )

            exe = fluid.Executor(place=places[0])
            exe.run(startup_prog)

            prog = fluid.CompiledProgram(main_prog)
            if len(places) > 1:
                prog = prog.with_data_parallel(
                    loss_name=loss.name, places=places
                )

            step_list = []
            loss_list = []
            start_t = time.time()
            for _ in range(EPOCH_NUM):
                step = 0
                for data in dataloader:
                    # With return_list=False each item is a feed dict,
                    # one per place.
                    self.assertEqual(len(data), len(places))
                    for place, item in zip(places, data):
                        # Distinct names so the graph vars `image`/`label`
                        # are not rebound to tensors.
                        image_t = item['image']
                        label_t = item['label']
                        self.assertEqual(
                            image_t.shape(), [BATCH_SIZE, IMAGE_SIZE]
                        )
                        self.assertEqual(label_t.shape(), [BATCH_SIZE, 1])
                        self.assertTrue(image_t._place()._equals(place))
                        self.assertTrue(label_t._place()._equals(place))
                    (loss_val,) = exe.run(
                        program=prog,
                        feed=data,
                        fetch_list=[loss],
                        use_program_cache=True,
                    )
                    loss_list.append(np.mean(loss_val))
                    step += 1
                step_list.append(step)

        end_t = time.time()
        ret = {
            "time": end_t - start_t,
            "step": step_list,
            "loss": np.array(loss_list),
        }
        print("time cost", ret['time'], 'step_list', ret['step'])
        return ret


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()