From 8e86721fe72845ae36ea9b9dc3576ac85b358336 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Fri, 6 Jul 2018 14:18:20 +0800 Subject: [PATCH] Fix data balance on single GPU --- paddle/fluid/framework/details/build_strategy.h | 2 +- .../framework/details/data_balance_op_handle.cc | 6 +++--- .../details/multi_devices_graph_builder.cc | 5 +++++ paddle/fluid/operators/read_op.cc | 10 +++++++--- .../fluid/tests/unittests/test_data_balance.py | 13 ++++++++++--- 5 files changed, 26 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 9c2c845c6ef..b2e5399e237 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -34,7 +34,7 @@ struct BuildStrategy { std::string debug_graphviz_path_{""}; - bool enable_data_balance_{true}; + bool enable_data_balance_{false}; }; } // namespace details diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index d07235df585..68896c8ac1b 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -86,9 +86,9 @@ std::vector> DataBalanceOpHandle::GetBalancePlan( } void DataBalanceOpHandle::RunImpl() { - if (places_.size() == 1) { - return; - } + PADDLE_ENFORCE_GT(places_.size(), 1, + "Data balance can only be enabled when the number of " + "places to run larger than 1."); auto in_var_handles = DynamicCast(inputs_); auto out_var_handles = DynamicCast(outputs_); PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 46d0c2769cb..b82c2ef4082 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -59,6 +59,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( grad_names_.insert(GradVarName(p)); } balance_vars_.resize(places_.size(), 0); + if (strategy_.enable_data_balance_ && places_.size() == 1) { + LOG(WARNING) << "It is no need to enable data balance when there is only " + "one place. enable_data_balance is set to False."; + strategy_.enable_data_balance_ = false; + } } void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 695d7ea83df..65fcce8bb01 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -92,9 +92,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("Reader", "(ReaderHolder) The executed reader."); AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable(); - AddAttr("throw_eof_exp", - "If set true, an exception will be thrown when the Reader " - "yields empty (which means there is no next data).") + AddAttr( + "throw_eof_exp", + "If set true, an exception will be thrown when the Reader " + "yields empty (which means there is no next data).\n" + "NOTES: This flag must be true always. It will be set to false" + " only when the data-balance is enabled in ParallelExecutor" + " and it is set by ParallelExecutor instance, not users.") .SetDefault(true); AddComment(R"DOC( Read Operator diff --git a/python/paddle/fluid/tests/unittests/test_data_balance.py b/python/paddle/fluid/tests/unittests/test_data_balance.py index cffa3329ac5..6d810920d55 100644 --- a/python/paddle/fluid/tests/unittests/test_data_balance.py +++ b/python/paddle/fluid/tests/unittests/test_data_balance.py @@ -103,8 +103,12 @@ class TestDataBalance(unittest.TestCase): exe = fluid.Executor(place) exe.run(startup_prog) + build_strategy = fluid.BuildStrategy() + build_strategy.enable_data_balance = True parallel_exe = fluid.ParallelExecutor( - use_cuda=self.use_cuda, main_program=main_prog) + use_cuda=self.use_cuda, + main_program=main_prog, + build_strategy=build_strategy) if (parallel_exe.device_count > self.batch_size): print("WARNING: Unittest TestDataBalance skipped. \ @@ -145,9 +149,12 @@ class TestDataBalance(unittest.TestCase): place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(startup_prog) - + build_strategy = fluid.BuildStrategy() + build_strategy.enable_data_balance = True parallel_exe = fluid.ParallelExecutor( - use_cuda=self.use_cuda, main_program=main_prog) + use_cuda=self.use_cuda, + main_program=main_prog, + build_strategy=build_strategy) if (parallel_exe.device_count > self.batch_size): print("WARNING: Unittest TestDataBalance skipped. \ -- GitLab