未验证 提交 8e86721f 编写于 作者: Y yuyang18

Fix data balance on single GPU

上级 d3a48484
......@@ -34,7 +34,7 @@ struct BuildStrategy {
std::string debug_graphviz_path_{""};
bool enable_data_balance_{true};
bool enable_data_balance_{false};
};
} // namespace details
......
......@@ -86,9 +86,9 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
}
void DataBalanceOpHandle::RunImpl() {
if (places_.size() == 1) {
return;
}
PADDLE_ENFORCE_GT(places_.size(), 1,
"Data balance can only be enabled when the number of "
"places to run larger than 1.");
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0);
......
......@@ -59,6 +59,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
grad_names_.insert(GradVarName(p));
}
balance_vars_.resize(places_.size(), 0);
if (strategy_.enable_data_balance_ && places_.size() == 1) {
LOG(WARNING) << "It is no need to enable data balance when there is only "
"one place. enable_data_balance is set to False.";
strategy_.enable_data_balance_ = false;
}
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
......
......@@ -92,9 +92,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("Reader", "(ReaderHolder) The executed reader.");
AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
AddAttr<bool>("throw_eof_exp",
"If set true, an exception will be thrown when the Reader "
"yields empty (which means there is no next data).")
AddAttr<bool>(
"throw_eof_exp",
"If set true, an exception will be thrown when the Reader "
"yields empty (which means there is no next data).\n"
"NOTES: This flag must be true always. It will be set to false"
" only when the data-balance is enabled in ParallelExecutor"
" and it is set by ParallelExecutor instance, not users.")
.SetDefault(true);
AddComment(R"DOC(
Read Operator
......
......@@ -103,8 +103,12 @@ class TestDataBalance(unittest.TestCase):
exe = fluid.Executor(place)
exe.run(startup_prog)
build_strategy = fluid.BuildStrategy()
build_strategy.enable_data_balance = True
parallel_exe = fluid.ParallelExecutor(
use_cuda=self.use_cuda, main_program=main_prog)
use_cuda=self.use_cuda,
main_program=main_prog,
build_strategy=build_strategy)
if (parallel_exe.device_count > self.batch_size):
print("WARNING: Unittest TestDataBalance skipped. \
......@@ -145,9 +149,12 @@ class TestDataBalance(unittest.TestCase):
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
build_strategy = fluid.BuildStrategy()
build_strategy.enable_data_balance = True
parallel_exe = fluid.ParallelExecutor(
use_cuda=self.use_cuda, main_program=main_prog)
use_cuda=self.use_cuda,
main_program=main_prog,
build_strategy=build_strategy)
if (parallel_exe.device_count > self.batch_size):
print("WARNING: Unittest TestDataBalance skipped. \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册