diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index c16e3006d7646fa068d87d643e748f0fe14a9f5c..e264906b57f006810bb37dba8a411fa34cea0ad8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -300,7 +300,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( auto nodes = graph->ReleaseNodes(); ir::Graph &result = *graph; - // int num_trainers = Get(kNumTrainers); + int num_trainers = Get(kNumTrainers); for (auto &node : nodes) { if (node->IsVar() && node->Var()) { @@ -387,7 +387,11 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( } // if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) { - if (!is_forwarding && nccl_ctxs_->contexts_.size() > 1) { + // insert synchronous ops at the backpropagation; and + // insert synchronous ops if the graph contains mutilple places. + if (!is_forwarding && + (places_.size() > 1 || num_trainers > 1 || + (nccl_ctxs_ && nccl_ctxs_->contexts_.size() > 1))) { // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. if (static_cast(boost::get(node->Op()->GetAttr( diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc index f1a07edf08843ccee42b759a2dc39517289ca853..214c2f762558bc8c194b9f1994a5b06c2054a2fc 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc @@ -49,18 +49,18 @@ FeedFetchList ParallelSSAGraphExecutor::Run( for (size_t i = 0; i < places_.size(); ++i) { auto call = [this, i, &fetch_tensors]() -> FeedFetchList { - return executors_[i]->Run(fetch_tensors); + try { + return executors_[i]->Run(fetch_tensors); + } catch (...) { + exception_holder_.Catch(std::current_exception()); + } + return FeedFetchList(); }; if (pool_) { run_futures.emplace_back(pool_->enqueue(std::move(call))); } else { - try { - fetch_datas.emplace_back(std::move(call())); - } catch (...) { - exception_holder_.Catch(std::current_exception()); - break; - } + call(); } } @@ -69,11 +69,7 @@ FeedFetchList ParallelSSAGraphExecutor::Run( if (exception_holder_.IsCaught()) { f.wait(); } else { - try { - fetch_datas.emplace_back(std::move(f.get())); - } catch (...) { - exception_holder_.Catch(std::current_exception()); - } + fetch_datas.emplace_back(std::move(f.get())); } } } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b0cd1e8e908d6522dc819241e187e7a2e7d3e17f..8d35361eb6506f4f4eca21f7ff2ea69f6dd6c1b9 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -87,7 +87,7 @@ ParallelExecutor::ParallelExecutor( "the number of places must be greater than 1."); PADDLE_ENFORCE(exec_strategy.type_ != ExecutionStrategy::kParallelGraph, "You should set build_strategy.reduce with 'AllReduce' for " - "ParallelGraph executor type"); + "the ParallelGraph executor type"); } // Step 1. Bcast the params to devs. diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h index 9b2a11bae12d242880829628faa089e1638424b0..517d66974433e1e7ae93bf86c8c17d85d46e7a8f 100644 --- a/paddle/fluid/operators/reader/ctr_reader.h +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -48,7 +48,7 @@ void MonitorThread(std::vector* thread_status, class CTRReader : public framework::FileReader { public: explicit CTRReader(const std::shared_ptr& queue, - int batch_size, int thread_num, + int batch_size, size_t thread_num, const std::vector& slots, const std::vector& file_list) : batch_size_(batch_size), slots_(slots), file_list_(file_list) { diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index 86f861674c26fe61e624103c2a0d70f816a1aebc..73b8fb74fa35303de59abe0e2477edbebdff0e0c 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -26,23 +26,26 @@ import sys __all__ = ['TestParallelExecutorBase'] +ExecutorType = fluid.ExecutionStrategy().ExecutorType + class TestParallelExecutorBase(unittest.TestCase): - def check_network_convergence(self, - method, - use_cuda=True, - memory_opt=True, - iter=50, - batch_size=None, - allow_op_delay=False, - feed_dict=None, - seed=None, - use_parallel_executor=True, - use_reduce=False, - fuse_elewise_add_act_ops=False, - optimizer=fluid.optimizer.Adam, - use_fast_executor=False, - enable_sequential_execution=False): + def check_network_convergence( + self, + method, + use_cuda=True, + memory_opt=True, + iter=50, + batch_size=None, + allow_op_delay=False, + feed_dict=None, + seed=None, + use_parallel_executor=True, + use_reduce=False, + fuse_elewise_add_act_ops=False, + optimizer=fluid.optimizer.Adam, + exec_type=fluid.ExecutionStrategy().ExecutorType.Default, + enable_sequential_execution=False): def run_executor(exe, feed, fetch_list, program=None): if isinstance(exe, fluid.ParallelExecutor): res = exe.run(fetch_list=fetch_list, feed=feed) @@ -58,68 +61,69 @@ class TestParallelExecutorBase(unittest.TestCase): startup = fluid.Program() startup.random_seed = 1 # Fix random seed main.random_seed = 1 - with fluid.program_guard(main, startup): - if seed is not None: - startup.random_seed = seed - main.random_seed = seed - - loss = method(use_feed=feed_dict is not None) - - optimizer().minimize(loss) - - if memory_opt: - fluid.memory_optimize(main) - - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() - startup_exe = fluid.Executor(place) - startup_exe.run(startup) - exec_strategy = fluid.ExecutionStrategy() - exec_strategy.allow_op_delay = allow_op_delay - if use_fast_executor: - exec_strategy.use_experimental_executor = True - - build_strategy = fluid.BuildStrategy() - build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \ - if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce - build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops - build_strategy.enable_sequential_execution = enable_sequential_execution - if use_cuda and core.is_compiled_with_cuda(): - build_strategy.remove_unnecessary_lock = True - - if use_parallel_executor: - exe = fluid.ParallelExecutor( - use_cuda, - loss_name=loss.name, - exec_strategy=exec_strategy, - build_strategy=build_strategy) - else: - exe = fluid.Executor(place=place) - - if batch_size is not None: - batch_size *= fluid.core.get_cuda_device_count( - ) if use_cuda else int( - os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - begin = time.time() - first_loss, = run_executor( - exe=exe, feed=feed_dict, fetch_list=[loss.name]) - - for i in range(iter): - run_executor(exe=exe, feed=feed_dict, fetch_list=[]) - - last_loss, = run_executor( - exe=exe, feed=feed_dict, fetch_list=[loss.name]) - end = time.time() - - if batch_size is not None: - print("%.4f Instance per second" % ( - (batch_size * iter + 2) / (end - begin))) - - avg_last_loss_val = np.array(last_loss).mean() - avg_first_loss_val = np.array(first_loss).mean() - if math.isnan(float(avg_last_loss_val)) or math.isnan( - float(avg_first_loss_val)): - sys.exit("got NaN loss, training failed.") - - print(first_loss, last_loss) - # self.assertGreater(first_loss[0], last_loss[0]) - return first_loss, last_loss + scope = fluid.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(main, startup): + if seed is not None: + startup.random_seed = seed + main.random_seed = seed + + loss = method(use_feed=feed_dict is not None) + + optimizer().minimize(loss) + + if memory_opt: + fluid.memory_optimize(main) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + startup_exe = fluid.Executor(place) + startup_exe.run(startup) + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.allow_op_delay = allow_op_delay + exec_strategy.executor_type = exec_type + + build_strategy = fluid.BuildStrategy() + build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \ + if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce + build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops + build_strategy.enable_sequential_execution = enable_sequential_execution + if use_cuda and core.is_compiled_with_cuda(): + build_strategy.remove_unnecessary_lock = True + + if use_parallel_executor: + exe = fluid.ParallelExecutor( + use_cuda, + loss_name=loss.name, + exec_strategy=exec_strategy, + build_strategy=build_strategy) + else: + exe = fluid.Executor(place=place) + + if batch_size is not None: + batch_size *= fluid.core.get_cuda_device_count( + ) if use_cuda else int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + begin = time.time() + first_loss, = run_executor( + exe=exe, feed=feed_dict, fetch_list=[loss.name]) + + for i in range(iter): + run_executor(exe=exe, feed=feed_dict, fetch_list=[]) + + last_loss, = run_executor( + exe=exe, feed=feed_dict, fetch_list=[loss.name]) + end = time.time() + + if batch_size is not None: + print("%.4f Instance per second" % ( + (batch_size * iter + 2) / (end - begin))) + + avg_last_loss_val = np.array(last_loss).mean() + avg_first_loss_val = np.array(first_loss).mean() + if math.isnan(float(avg_last_loss_val)) or math.isnan( + float(avg_first_loss_val)): + sys.exit("got NaN loss, training failed.") + + print(first_loss, last_loss) + # self.assertGreater(first_loss[0], last_loss[0]) + return first_loss, last_loss diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py index 84b0aad8acb096a32f625e32fb640599f2882d97..d75761153c0a119b85d918b188928dbe17f4fc51 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py @@ -181,6 +181,9 @@ class TestCRFModel(unittest.TestCase): if core.is_compiled_with_cuda(): self.check_network_convergence( is_sparse=True, build_strategy=build_strategy, use_cuda=True) + self.check_network_convergence( + is_sparse=True, build_strategy=build_strategy, use_cuda=True) + self.check_network_convergence( is_sparse=True, build_strategy=build_strategy, use_cuda=False) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py index 3eecc4670152e72443f731c71d7db67ca8e02e72..3dddff0d99dcb9463dbc489b88f1cdd17cc8bcf4 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py @@ -20,7 +20,7 @@ import numpy as np import paddle.fluid.core as core import os import paddle.fluid as fluid -from parallel_executor_test_base import TestParallelExecutorBase +from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType def simple_fc_net(use_feed): @@ -99,7 +99,10 @@ class TestMNIST(TestParallelExecutorBase): self.assertAlmostEqual(loss[0], loss[1], delta=1e-4) # simple_fc - def check_simple_fc_convergence(self, use_cuda, use_reduce=False): + def check_simple_fc_convergence(self, + use_cuda, + use_reduce=False, + exec_type=ExecutorType.Default): if use_cuda and not core.is_compiled_with_cuda(): return @@ -110,19 +113,21 @@ class TestMNIST(TestParallelExecutorBase): feed_dict={"image": img, "label": label}, use_cuda=use_cuda, - use_reduce=use_reduce) + use_reduce=use_reduce, + exec_type=exec_type) def test_simple_fc(self): # use_cuda - self.check_simple_fc_convergence(True) + self.check_simple_fc_convergence(True, ExecutorType.Default) + self.check_simple_fc_convergence(True, ExecutorType.ParallelGraph) self.check_simple_fc_convergence(False) def test_simple_fc_with_new_strategy(self): - # use_cuda, use_reduce + # use_cuda, use_reducea self._compare_reduce_and_allreduce(simple_fc_net, True) self._compare_reduce_and_allreduce(simple_fc_net, False) - def check_simple_fc_parallel_accuracy(self, use_cuda): + def check_simple_fc_parallel_accuracy(self, use_cuda, exec_type): if use_cuda and not core.is_compiled_with_cuda(): return @@ -134,14 +139,16 @@ class TestMNIST(TestParallelExecutorBase): feed_dict={"image": img, "label": label}, use_cuda=use_cuda, - use_parallel_executor=False) + use_parallel_executor=False, + exec_type=exec_type) parallel_first_loss, parallel_last_loss = self.check_network_convergence( method=simple_fc_net, seed=1, feed_dict={"image": img, "label": label}, use_cuda=use_cuda, - use_parallel_executor=True) + use_parallel_executor=True, + exec_type=exec_type) self.assertAlmostEquals( np.mean(parallel_first_loss), @@ -151,10 +158,12 @@ class TestMNIST(TestParallelExecutorBase): np.mean(parallel_last_loss), single_last_loss, delta=1e-6) def test_simple_fc_parallel_accuracy(self): - self.check_simple_fc_parallel_accuracy(True) - self.check_simple_fc_parallel_accuracy(False) + self.check_simple_fc_parallel_accuracy(True, ExecutorType.Default) + self.check_simple_fc_parallel_accuracy(True, ExecutorType.ParallelGraph) + # FIXME(Yancey1989): ParallelGraph executor type support CPU mode + self.check_simple_fc_parallel_accuracy(False, ExecutorType.Default) - def check_batchnorm_fc_convergence(self, use_cuda, use_fast_executor): + def check_batchnorm_fc_convergence(self, use_cuda, exec_type): if use_cuda and not core.is_compiled_with_cuda(): return @@ -165,12 +174,13 @@ class TestMNIST(TestParallelExecutorBase): feed_dict={"image": img, "label": label}, use_cuda=use_cuda, - use_fast_executor=use_fast_executor) + exec_type=exec_type) def test_batchnorm_fc(self): for use_cuda in (False, True): - for use_fast_executor in (False, True): - self.check_batchnorm_fc_convergence(use_cuda, use_fast_executor) + for exec_type in (ExecutorType.Default, ExecutorType.Experimental, + ExecutorType.ParallelGraph): + self.check_batchnorm_fc_convergence(use_cuda, exec_type) def test_batchnorm_fc_with_new_strategy(self): # FIXME(zcd): close this test temporally. diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py index e7a56bb6386a812e43e5c1b5c08cd0682aa9223a..bada38894f7b23429286df314923f975801163ea 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py @@ -19,7 +19,7 @@ import paddle.fluid.layers.ops as ops from paddle.fluid.initializer import init_on_cpu from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter import paddle.fluid.core as core -from parallel_executor_test_base import TestParallelExecutorBase +from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType import unittest import math import os @@ -167,13 +167,17 @@ def cosine_decay(learning_rate, step_each_epoch, epochs=120): return decayed_lr -def optimizer(learning_rate=0.01): - optimizer = fluid.optimizer.Momentum( - learning_rate=cosine_decay( - learning_rate=learning_rate, step_each_epoch=2, epochs=1), - momentum=0.9, - regularization=fluid.regularizer.L2Decay(1e-4)) - return optimizer +def optimizer(learning_rate=0.01, lr_scale=1.0): + def _opt(): + return fluid.optimizer.Momentum( + learning_rate=cosine_decay( + learning_rate=learning_rate / lr_scale, + step_each_epoch=2, + epochs=1), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + + return _opt class TestResnet(TestParallelExecutorBase): @@ -216,7 +220,7 @@ class TestResnet(TestParallelExecutorBase): batch_size=batch_size, use_cuda=use_cuda, use_reduce=False, - optimizer=optimizer) + optimizer=optimizer()) reduce_first_loss, reduce_last_loss = self.check_network_convergence( model, feed_dict={"image": img, @@ -225,7 +229,7 @@ class TestResnet(TestParallelExecutorBase): batch_size=batch_size, use_cuda=use_cuda, use_reduce=True, - optimizer=optimizer) + optimizer=optimizer()) for loss in zip(all_reduce_first_loss, reduce_first_loss): self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) @@ -243,7 +247,7 @@ class TestResnet(TestParallelExecutorBase): batch_size=batch_size, use_cuda=use_cuda, use_reduce=False, - optimizer=optimizer, + optimizer=optimizer(), enable_sequential_execution=True) reduce_first_loss_seq, reduce_last_loss_seq = self.check_network_convergence( @@ -254,7 +258,7 @@ class TestResnet(TestParallelExecutorBase): batch_size=batch_size, use_cuda=use_cuda, use_reduce=True, - optimizer=optimizer, + optimizer=optimizer(), enable_sequential_execution=True) for loss in zip(all_reduce_first_loss, all_reduce_first_loss_seq): @@ -277,7 +281,9 @@ class TestResnet(TestParallelExecutorBase): use_cuda=True, use_reduce=False, iter=20, - delta2=1e-6): + delta2=1e-6, + exec_type=ExecutorType.Default, + lr_scale=1.0): if use_cuda and not core.is_compiled_with_cuda(): return @@ -295,8 +301,9 @@ class TestResnet(TestParallelExecutorBase): batch_size=batch_size, use_cuda=use_cuda, use_reduce=use_reduce, - optimizer=optimizer, - use_parallel_executor=False) + optimizer=optimizer(), + use_parallel_executor=False, + exec_type=exec_type) parallel_first_loss, parallel_last_loss = self.check_network_convergence( model, feed_dict={"image": img, @@ -305,7 +312,8 @@ class TestResnet(TestParallelExecutorBase): batch_size=batch_size, use_cuda=use_cuda, use_reduce=use_reduce, - optimizer=optimizer) + optimizer=optimizer(lr_scale=lr_scale), + exec_type=exec_type) self.assertAlmostEquals( np.mean(parallel_first_loss), single_first_loss[0], delta=1e-6) @@ -313,7 +321,14 @@ class TestResnet(TestParallelExecutorBase): np.mean(parallel_last_loss), single_last_loss[0], delta=delta2) def test_seresnext_with_learning_rate_decay(self): - self._check_resnet_convergence(model=SE_ResNeXt50Small, use_cuda=True) + if core.is_compiled_with_cuda(): + self._check_resnet_convergence( + model=SE_ResNeXt50Small, use_cuda=True) + self._check_resnet_convergence( + model=SE_ResNeXt50Small, + use_cuda=True, + exec_type=ExecutorType.ParallelGraph, + lr_scale=core.get_cuda_device_count()) self._check_resnet_convergence( model=SE_ResNeXt50Small, use_cuda=False, iter=2, delta2=1e-3) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py index 3827743908c1d76931572277323d1dd5ddd05523..b5ee72a24e60d39ca4d85c3d10548a584b577a0a 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py @@ -17,7 +17,7 @@ from __future__ import print_function import paddle.fluid as fluid import transformer_model import numpy as np -from parallel_executor_test_base import TestParallelExecutorBase +from parallel_executor_test_base import TestParallelExecutorBase, ExecutorType import unittest import paddle import paddle.fluid.core as core @@ -173,6 +173,10 @@ class TestTransformer(TestParallelExecutorBase): def test_main(self): if core.is_compiled_with_cuda(): self.check_network_convergence(transformer, use_cuda=True) + self.check_network_convergence( + transformer, + use_cuda=True, + exec_type=ExecutorType.ParallelGraph) self.check_network_convergence( transformer, use_cuda=True, enable_sequential_execution=True) self.check_network_convergence(transformer, use_cuda=False, iter=5)