提交 ff4317ce 编写于 作者: F fengjiayi

follow comments

上级 3606a306
......@@ -33,6 +33,8 @@ struct BuildStrategy {
GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
std::string debug_graphviz_path_{""};
bool enable_data_balance_{true};
};
} // namespace details
......
......@@ -73,7 +73,9 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) {
if (size_device_vec[src_idx][0] <= expected_device_size) {
++src_idx;
PADDLE_ENFORCE_LT(src_idx, device_num - empty_num);
PADDLE_ENFORCE_LT(
src_idx, device_num - empty_num,
"In current srategy an empty tensor should not be copy source.");
}
size_device_vec[src_idx][0] -= expected_device_size;
size_device_vec[dst_idx][0] += expected_device_size;
......@@ -113,7 +115,9 @@ void DataBalanceOpHandle::RunImpl() {
if (data_idx == 0) {
device_sizes.emplace_back(ins_size);
} else {
PADDLE_ENFORCE_EQ(ins_size, device_sizes.at(place_idx));
PADDLE_ENFORCE_EQ(
ins_size, device_sizes.at(place_idx),
"All data on the same device shall have the same batch size.");
}
}
const auto &balance_plan = GetBalancePlan(device_sizes);
......
......@@ -216,7 +216,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} else {
// This op runs on all devices, and its outputs may include parameter
// gradients.
if (op->Type() == "read") {
if (op->Type() == "read" && strategy_.enable_data_balance_) {
op->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, *op, places_.size());
const auto &data_var_names = op->Output("Out");
......
......@@ -58,6 +58,7 @@ void OpHandleBase::Run(bool use_cuda) {
void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_NOT_NULL(waited_ctx);
if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctxes_) {
PADDLE_ENFORCE_NOT_NULL(dev_ctx.second);
......
......@@ -643,7 +643,11 @@ All parameter, weight, gradient are variables in Paddle.
[](const BuildStrategy &self) { return self.debug_graphviz_path_; },
[](BuildStrategy &self, const std::string &path) {
self.debug_graphviz_path_ = path;
});
})
.def_property(
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; });
pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &,
......
......@@ -4,3 +4,5 @@ mnist_1.recordio
mnist_2.recordio
flowers.recordio
wmt16.recordio
data_balance_test.recordio
data_balance_with_lod_test.recordio
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册