From 10d85208bf7b601f79be062c5ef6afc8e9d42715 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Wed, 22 Apr 2020 14:00:01 +0800 Subject: [PATCH] fix test_multiprocess_dataloader_base timeout. test=develop (#24053) --- .../fluid/tests/unittests/CMakeLists.txt | 6 +- .../test_multiprocess_dataloader_dynamic.py | 132 +++++++++++++++++ ...=> test_multiprocess_dataloader_static.py} | 138 ++++-------------- 3 files changed, 162 insertions(+), 114 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py rename python/paddle/fluid/tests/unittests/{test_multiprocess_dataloader_base.py => test_multiprocess_dataloader_static.py} (53%) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 4ddcae9b240..d89dd1ed1cd 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -211,7 +211,8 @@ if (APPLE OR WIN32) list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_fds_clear) list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_exit_func) list(REMOVE_ITEM TEST_OPS test_imperative_signal_handler) - list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_base) + list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_static) + list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_dynamic) list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_exception) endif() @@ -383,6 +384,7 @@ if(NOT WIN32 AND NOT APPLE) set_tests_properties(test_imperative_data_loader_base PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) set_tests_properties(test_imperative_data_loader_fds_clear PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) # set_tests_properties(test_imperative_data_loader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) - set_tests_properties(test_multiprocess_dataloader_base PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) + set_tests_properties(test_multiprocess_dataloader_static PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) + set_tests_properties(test_multiprocess_dataloader_dynamic PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) set_tests_properties(test_multiprocess_dataloader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) endif() diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py new file mode 100644 index 00000000000..6af273faf39 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import os +import sys +import six +import time +import unittest +import multiprocessing +import numpy as np + +import paddle.fluid as fluid +from paddle.io import Dataset, BatchSampler, DataLoader +from paddle.fluid.dygraph.nn import Linear +from paddle.fluid.dygraph.base import to_variable + +from test_multiprocess_dataloader_static import RandomDataset, prepare_places + +EPOCH_NUM = 5 +BATCH_SIZE = 16 +IMAGE_SIZE = 784 +SAMPLE_NUM = 400 +CLASS_NUM = 10 + + +class SimpleFCNet(fluid.dygraph.Layer): + def __init__(self): + super(SimpleFCNet, self).__init__() + + param_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.8)) + bias_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant( + value=0.5)) + self._fcs = [] + in_channel = IMAGE_SIZE + for hidden_size in [10, 20, 30]: + self._fcs.append( + Linear( + in_channel, + hidden_size, + act='tanh', + param_attr=param_attr, + bias_attr=bias_attr)) + in_channel = hidden_size + self._fcs.append( + Linear( + in_channel, + CLASS_NUM, + act='softmax', + param_attr=param_attr, + bias_attr=bias_attr)) + + def forward(self, image): + out = image + for fc in self._fcs: + out = fc(out) + return out + + +class TestDygraphDataLoader(unittest.TestCase): + def run_main(self, num_workers, places): + fluid.default_startup_program().random_seed = 1 + fluid.default_main_program().random_seed = 1 + with fluid.dygraph.guard(places[0]): + fc_net = SimpleFCNet() + optimizer = fluid.optimizer.Adam(parameter_list=fc_net.parameters()) + + dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) + dataloader = DataLoader( + dataset, + places=places, + num_workers=num_workers, + batch_size=BATCH_SIZE, + drop_last=True) + assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) + + step_list = [] + loss_list = [] + start_t = time.time() + for _ in six.moves.range(EPOCH_NUM): + step = 0 + for image, label in dataloader(): + out = fc_net(image) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.reduce_mean(loss) + avg_loss.backward() + optimizer.minimize(avg_loss) + fc_net.clear_gradients() + + loss_list.append(np.mean(avg_loss.numpy())) + step += 1 + step_list.append(step) + + end_t = time.time() + ret = { + "time": end_t - start_t, + "step": step_list, + "loss": np.array(loss_list) + } + print("time cost", ret['time'], 'step_list', ret['step']) + return ret + + def test_main(self): + # dynamic graph do not run with_data_parallel + for p in prepare_places(False): + results = [] + for num_workers in [0, 2]: + print(self.__class__.__name__, p, num_workers) + sys.stdout.flush() + ret = self.run_main(num_workers=num_workers, places=p) + results.append(ret) + diff = np.max( + np.abs(results[0]['loss'] - results[1]['loss']) / + np.abs(results[0]['loss'])) + self.assertLess(diff, 1e-2) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_base.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py similarity index 53% rename from python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_base.py rename to python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py index d6b3ed710ca..2d75126ec42 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_base.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py @@ -24,8 +24,6 @@ import numpy as np import paddle.fluid as fluid from paddle.io import Dataset, BatchSampler, DataLoader -from paddle.fluid.dygraph.nn import Linear -from paddle.fluid.dygraph.base import to_variable EPOCH_NUM = 5 BATCH_SIZE = 16 @@ -86,42 +84,24 @@ def simple_fc_net_static(): return startup_prog, main_prog, image, label, loss -class SimpleFCNet(fluid.dygraph.Layer): - def __init__(self): - super(SimpleFCNet, self).__init__() +def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True): + places = [] + if with_cpu: + places.append([fluid.CPUPlace()]) + if with_data_parallel: + places.append([fluid.CPUPlace()] * 2) - param_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.8)) - bias_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.5)) - self._fcs = [] - in_channel = IMAGE_SIZE - for hidden_size in [10, 20, 30]: - self._fcs.append( - Linear( - in_channel, - hidden_size, - act='tanh', - param_attr=param_attr, - bias_attr=bias_attr)) - in_channel = hidden_size - self._fcs.append( - Linear( - in_channel, - CLASS_NUM, - act='softmax', - param_attr=param_attr, - bias_attr=bias_attr)) - - def forward(self, image): - out = image - for fc in self._fcs: - out = fc(out) - return out + if with_gpu and fluid.core.is_compiled_with_cuda(): + tmp = fluid.cuda_places()[:2] + assert len(tmp) > 0, "no gpu detected" + if with_data_parallel: + places.append(tmp) + places.append([tmp[0]]) + return places class TestStaticDataLoader(unittest.TestCase): - def run_main(self, num_workers, places, with_data_parallel): + def run_main(self, num_workers, places): scope = fluid.Scope() with fluid.scope_guard(scope): startup_prog, main_prog, image, label, loss = simple_fc_net_static() @@ -140,7 +120,7 @@ class TestStaticDataLoader(unittest.TestCase): exe.run(startup_prog) prog = fluid.CompiledProgram(main_prog) - if with_data_parallel: + if len(places) > 1: prog = prog.with_data_parallel( loss_name=loss.name, places=places) @@ -176,84 +156,18 @@ class TestStaticDataLoader(unittest.TestCase): print("time cost", ret['time'], 'step_list', ret['step']) return ret - def prepare_places(self, with_data_parallel, with_cpu=True, with_gpu=True): - places = [] - # FIXME: PR_CI_Py35 may hang on Multi-CPUs with multiprocess, but it - # works fine locally, this should be fixed. OTOH, multiprocessing - # is not recommended when running on CPU generally - if with_cpu and not sys.version.startswith('3.5'): - places.append([fluid.CPUPlace()]) - if with_data_parallel: - places.append([fluid.CPUPlace()] * 2) - - if with_gpu and fluid.core.is_compiled_with_cuda(): - tmp = fluid.cuda_places()[:2] - assert len(tmp) > 0, "no gpu detected" - if with_data_parallel: - places.append(tmp) - places.append([tmp[0]]) - return places - def test_main(self): - for with_data_parallel in [False] if self.__class__.__name__ \ - == "TestDygraphDataLoader" else [True, False]: - for p in self.prepare_places(with_data_parallel): - results = [] - for num_workers in [0, 2]: - print(self.__class__.__name__, p, num_workers) - ret = self.run_main( - num_workers=num_workers, - places=p, - with_data_parallel=with_data_parallel) - results.append(ret) - diff = np.max( - np.abs(results[0]['loss'] - results[1]['loss']) / - np.abs(results[0]['loss'])) - self.assertLess(diff, 1e-2) - - -class TestDygraphDataLoader(TestStaticDataLoader): - def run_main(self, num_workers, places, with_data_parallel): - fluid.default_startup_program().random_seed = 1 - fluid.default_main_program().random_seed = 1 - with fluid.dygraph.guard(places[0]): - fc_net = SimpleFCNet() - optimizer = fluid.optimizer.Adam(parameter_list=fc_net.parameters()) - - dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) - dataloader = DataLoader( - dataset, - places=places, - num_workers=num_workers, - batch_size=BATCH_SIZE, - drop_last=True) - assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) - - step_list = [] - loss_list = [] - start_t = time.time() - for _ in six.moves.range(EPOCH_NUM): - step = 0 - for image, label in dataloader(): - out = fc_net(image) - loss = fluid.layers.cross_entropy(out, label) - avg_loss = fluid.layers.reduce_mean(loss) - avg_loss.backward() - optimizer.minimize(avg_loss) - fc_net.clear_gradients() - - loss_list.append(np.mean(avg_loss.numpy())) - step += 1 - step_list.append(step) - - end_t = time.time() - ret = { - "time": end_t - start_t, - "step": step_list, - "loss": np.array(loss_list) - } - print("time cost", ret['time'], 'step_list', ret['step']) - return ret + for p in prepare_places(True): + results = [] + for num_workers in [0, 2]: + print(self.__class__.__name__, p, num_workers) + sys.stdout.flush() + ret = self.run_main(num_workers=num_workers, places=p) + results.append(ret) + diff = np.max( + np.abs(results[0]['loss'] - results[1]['loss']) / + np.abs(results[0]['loss'])) + self.assertLess(diff, 1e-2) if __name__ == '__main__': -- GitLab