未验证 提交 8e86721f 编写于 作者: Y yuyang18

Fix data balance on single GPU

上级 d3a48484
...@@ -34,7 +34,7 @@ struct BuildStrategy { ...@@ -34,7 +34,7 @@ struct BuildStrategy {
std::string debug_graphviz_path_{""}; std::string debug_graphviz_path_{""};
bool enable_data_balance_{true}; bool enable_data_balance_{false};
}; };
} // namespace details } // namespace details
......
...@@ -86,9 +86,9 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan( ...@@ -86,9 +86,9 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
} }
void DataBalanceOpHandle::RunImpl() { void DataBalanceOpHandle::RunImpl() {
if (places_.size() == 1) { PADDLE_ENFORCE_GT(places_.size(), 1,
return; "Data balance can only be enabled when the number of "
} "places to run larger than 1.");
auto in_var_handles = DynamicCast<VarHandle>(inputs_); auto in_var_handles = DynamicCast<VarHandle>(inputs_);
auto out_var_handles = DynamicCast<VarHandle>(outputs_); auto out_var_handles = DynamicCast<VarHandle>(outputs_);
PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0); PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0);
......
...@@ -59,6 +59,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -59,6 +59,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
} }
balance_vars_.resize(places_.size(), 0); balance_vars_.resize(places_.size(), 0);
if (strategy_.enable_data_balance_ && places_.size() == 1) {
LOG(WARNING) << "It is no need to enable data balance when there is only "
"one place. enable_data_balance is set to False.";
strategy_.enable_data_balance_ = false;
}
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
......
...@@ -92,9 +92,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -92,9 +92,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("Reader", "(ReaderHolder) The executed reader."); AddInput("Reader", "(ReaderHolder) The executed reader.");
AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable(); AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
AddAttr<bool>("throw_eof_exp", AddAttr<bool>(
"If set true, an exception will be thrown when the Reader " "throw_eof_exp",
"yields empty (which means there is no next data).") "If set true, an exception will be thrown when the Reader "
"yields empty (which means there is no next data).\n"
"NOTES: This flag must be true always. It will be set to false"
" only when the data-balance is enabled in ParallelExecutor"
" and it is set by ParallelExecutor instance, not users.")
.SetDefault(true); .SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
Read Operator Read Operator
......
...@@ -103,8 +103,12 @@ class TestDataBalance(unittest.TestCase): ...@@ -103,8 +103,12 @@ class TestDataBalance(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
build_strategy = fluid.BuildStrategy()
build_strategy.enable_data_balance = True
parallel_exe = fluid.ParallelExecutor( parallel_exe = fluid.ParallelExecutor(
use_cuda=self.use_cuda, main_program=main_prog) use_cuda=self.use_cuda,
main_program=main_prog,
build_strategy=build_strategy)
if (parallel_exe.device_count > self.batch_size): if (parallel_exe.device_count > self.batch_size):
print("WARNING: Unittest TestDataBalance skipped. \ print("WARNING: Unittest TestDataBalance skipped. \
...@@ -145,9 +149,12 @@ class TestDataBalance(unittest.TestCase): ...@@ -145,9 +149,12 @@ class TestDataBalance(unittest.TestCase):
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
build_strategy = fluid.BuildStrategy()
build_strategy.enable_data_balance = True
parallel_exe = fluid.ParallelExecutor( parallel_exe = fluid.ParallelExecutor(
use_cuda=self.use_cuda, main_program=main_prog) use_cuda=self.use_cuda,
main_program=main_prog,
build_strategy=build_strategy)
if (parallel_exe.device_count > self.batch_size): if (parallel_exe.device_count > self.batch_size):
print("WARNING: Unittest TestDataBalance skipped. \ print("WARNING: Unittest TestDataBalance skipped. \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册