diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 64e83acb4dc1995800c4ca3caf81668b24a7c9fe..9c2c845c6efb206fb1ad5150189430b9a6fe9ea3 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -33,6 +33,8 @@ struct BuildStrategy { GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice}; std::string debug_graphviz_path_{""}; + + bool enable_data_balance_{true}; }; } // namespace details diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index f8d431ef2a35823ed853b543fcd1d9b6064a4058..b914851fe0add74f6d85589f4686224b668b8064 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -73,7 +73,9 @@ std::vector> DataBalanceOpHandle::GetBalancePlan( for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) { if (size_device_vec[src_idx][0] <= expected_device_size) { ++src_idx; - PADDLE_ENFORCE_LT(src_idx, device_num - empty_num); + PADDLE_ENFORCE_LT( + src_idx, device_num - empty_num, + "In current srategy an empty tensor should not be copy source."); } size_device_vec[src_idx][0] -= expected_device_size; size_device_vec[dst_idx][0] += expected_device_size; @@ -113,7 +115,9 @@ void DataBalanceOpHandle::RunImpl() { if (data_idx == 0) { device_sizes.emplace_back(ins_size); } else { - PADDLE_ENFORCE_EQ(ins_size, device_sizes.at(place_idx)); + PADDLE_ENFORCE_EQ( + ins_size, device_sizes.at(place_idx), + "All data on the same device shall have the same batch size."); } } const auto &balance_plan = GetBalancePlan(device_sizes); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index edfefb8231f969d3f6aa1b3cb13a341d9a25aaf4..46d0c2769cb334f5cb75ae0ef5e48da45448c48f 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -216,7 +216,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } else { // This op runs on all devices, and its output may have parameter's // gradients. - if (op->Type() == "read") { + if (op->Type() == "read" && strategy_.enable_data_balance_) { op->SetAttr("throw_eof_exp", false); CreateComputationalOps(&result, *op, places_.size()); const auto &data_var_names = op->Output("Out"); diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 856124875d55e65428a7fb23e402c0d311900724..3560fabb424375a770432586fe7c8e51210b3d0c 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -58,6 +58,7 @@ void OpHandleBase::Run(bool use_cuda) { void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { #ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_NOT_NULL(waited_ctx); if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) { for (auto &dev_ctx : dev_ctxes_) { PADDLE_ENFORCE_NOT_NULL(dev_ctx.second); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 36d080996831d4ad90d92baeafbe964693e2332a..9fc647a7d2a2bdfbaeeb91b00b4183f5c80b5aba 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -643,7 +643,11 @@ All parameter, weight, gradient are variables in Paddle. [](const BuildStrategy &self) { return self.debug_graphviz_path_; }, [](BuildStrategy &self, const std::string &path) { self.debug_graphviz_path_ = path; - }); + }) + .def_property( + "enable_data_balance", + [](const BuildStrategy &self) { return self.enable_data_balance_; }, + [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; }); pe.def(py::init &, const std::unordered_set &, diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore index 3538a9c2009bb133609153427981fb66974377fa..b1e8fda03aa42f5f7528eafb46c16d55b868bae5 100644 --- a/python/paddle/fluid/tests/unittests/.gitignore +++ b/python/paddle/fluid/tests/unittests/.gitignore @@ -4,3 +4,5 @@ mnist_1.recordio mnist_2.recordio flowers.recordio wmt16.recordio +data_balance_test.recordio +data_balance_with_lod_test.recordio