diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 9c2c845c6efb206fb1ad5150189430b9a6fe9ea3..b2e5399e2376a86c1cd310b29c768832665af87f 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -34,7 +34,7 @@ struct BuildStrategy { std::string debug_graphviz_path_{""}; - bool enable_data_balance_{true}; + bool enable_data_balance_{false}; }; } // namespace details diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index d07235df5856591f8ad707c86fa5b3b65868c3d1..68896c8ac1bae7d4bfcfa79cc8ec5c26bf2d93ee 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -86,9 +86,9 @@ std::vector> DataBalanceOpHandle::GetBalancePlan( } void DataBalanceOpHandle::RunImpl() { - if (places_.size() == 1) { - return; - } + PADDLE_ENFORCE_GT(places_.size(), 1, + "Data balance can only be enabled when the number of " + "places to run larger than 1."); auto in_var_handles = DynamicCast(inputs_); auto out_var_handles = DynamicCast(outputs_); PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 46d0c2769cb334f5cb75ae0ef5e48da45448c48f..b82c2ef4082110f1621eb38d50361396511a4825 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -59,6 +59,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( grad_names_.insert(GradVarName(p)); } balance_vars_.resize(places_.size(), 0); + if (strategy_.enable_data_balance_ && places_.size() == 1) { + LOG(WARNING) << "It is no need to enable data balance when there is only " + "one place. enable_data_balance is set to False."; + strategy_.enable_data_balance_ = false; + } } void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 695d7ea83df952d9f2212cc0aaca5c90c7b47ee7..65fcce8bb019965a805ad09d50be0aba64e4f24e 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -92,9 +92,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("Reader", "(ReaderHolder) The executed reader."); AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable(); - AddAttr("throw_eof_exp", - "If set true, an exception will be thrown when the Reader " - "yields empty (which means there is no next data).") + AddAttr( + "throw_eof_exp", + "If set true, an exception will be thrown when the Reader " + "yields empty (which means there is no next data).\n" + "NOTES: This flag must be true always. It will be set to false" + " only when the data-balance is enabled in ParallelExecutor" + " and it is set by ParallelExecutor instance, not users.") .SetDefault(true); AddComment(R"DOC( Read Operator diff --git a/python/paddle/fluid/tests/unittests/test_data_balance.py b/python/paddle/fluid/tests/unittests/test_data_balance.py index cffa3329ac556dc77f3cb508b807cbd49bb974f7..6d810920d55ccf069ff408c553069e8f5e590271 100644 --- a/python/paddle/fluid/tests/unittests/test_data_balance.py +++ b/python/paddle/fluid/tests/unittests/test_data_balance.py @@ -103,8 +103,12 @@ class TestDataBalance(unittest.TestCase): exe = fluid.Executor(place) exe.run(startup_prog) + build_strategy = fluid.BuildStrategy() + build_strategy.enable_data_balance = True parallel_exe = fluid.ParallelExecutor( - use_cuda=self.use_cuda, main_program=main_prog) + use_cuda=self.use_cuda, + main_program=main_prog, + build_strategy=build_strategy) if (parallel_exe.device_count > self.batch_size): print("WARNING: Unittest TestDataBalance skipped. \ @@ -145,9 +149,12 @@ class TestDataBalance(unittest.TestCase): place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(startup_prog) - + build_strategy = fluid.BuildStrategy() + build_strategy.enable_data_balance = True parallel_exe = fluid.ParallelExecutor( - use_cuda=self.use_cuda, main_program=main_prog) + use_cuda=self.use_cuda, + main_program=main_prog, + build_strategy=build_strategy) if (parallel_exe.device_count > self.batch_size): print("WARNING: Unittest TestDataBalance skipped. \