提交 ff4317ce 编写于 作者: F fengjiayi

follow comments

上级 3606a306
...@@ -33,6 +33,8 @@ struct BuildStrategy { ...@@ -33,6 +33,8 @@ struct BuildStrategy {
GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice}; GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
std::string debug_graphviz_path_{""}; std::string debug_graphviz_path_{""};
bool enable_data_balance_{true};
}; };
} // namespace details } // namespace details
......
...@@ -73,7 +73,9 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan( ...@@ -73,7 +73,9 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) { for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) {
if (size_device_vec[src_idx][0] <= expected_device_size) { if (size_device_vec[src_idx][0] <= expected_device_size) {
++src_idx; ++src_idx;
PADDLE_ENFORCE_LT(src_idx, device_num - empty_num); PADDLE_ENFORCE_LT(
src_idx, device_num - empty_num,
"In current srategy an empty tensor should not be copy source.");
} }
size_device_vec[src_idx][0] -= expected_device_size; size_device_vec[src_idx][0] -= expected_device_size;
size_device_vec[dst_idx][0] += expected_device_size; size_device_vec[dst_idx][0] += expected_device_size;
...@@ -113,7 +115,9 @@ void DataBalanceOpHandle::RunImpl() { ...@@ -113,7 +115,9 @@ void DataBalanceOpHandle::RunImpl() {
if (data_idx == 0) { if (data_idx == 0) {
device_sizes.emplace_back(ins_size); device_sizes.emplace_back(ins_size);
} else { } else {
PADDLE_ENFORCE_EQ(ins_size, device_sizes.at(place_idx)); PADDLE_ENFORCE_EQ(
ins_size, device_sizes.at(place_idx),
"All data on the same device shall have the same batch size.");
} }
} }
const auto &balance_plan = GetBalancePlan(device_sizes); const auto &balance_plan = GetBalancePlan(device_sizes);
......
...@@ -216,7 +216,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -216,7 +216,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} else { } else {
// This op runs on all devices, and its output may have parameter's // This op runs on all devices, and its output may have parameter's
// gradients. // gradients.
if (op->Type() == "read") { if (op->Type() == "read" && strategy_.enable_data_balance_) {
op->SetAttr("throw_eof_exp", false); op->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, *op, places_.size());
const auto &data_var_names = op->Output("Out"); const auto &data_var_names = op->Output("Out");
......
...@@ -58,6 +58,7 @@ void OpHandleBase::Run(bool use_cuda) { ...@@ -58,6 +58,7 @@ void OpHandleBase::Run(bool use_cuda) {
void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_NOT_NULL(waited_ctx);
if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) { if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctxes_) { for (auto &dev_ctx : dev_ctxes_) {
PADDLE_ENFORCE_NOT_NULL(dev_ctx.second); PADDLE_ENFORCE_NOT_NULL(dev_ctx.second);
......
...@@ -643,7 +643,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -643,7 +643,11 @@ All parameter, weight, gradient are variables in Paddle.
[](const BuildStrategy &self) { return self.debug_graphviz_path_; }, [](const BuildStrategy &self) { return self.debug_graphviz_path_; },
[](BuildStrategy &self, const std::string &path) { [](BuildStrategy &self, const std::string &path) {
self.debug_graphviz_path_ = path; self.debug_graphviz_path_ = path;
}); })
.def_property(
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; });
pe.def(py::init<const std::vector<platform::Place> &, pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &, const std::unordered_set<std::string> &,
......
...@@ -4,3 +4,5 @@ mnist_1.recordio ...@@ -4,3 +4,5 @@ mnist_1.recordio
mnist_2.recordio mnist_2.recordio
flowers.recordio flowers.recordio
wmt16.recordio wmt16.recordio
data_balance_test.recordio
data_balance_with_lod_test.recordio
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册