提交 2e320079 编写于 作者: F fengjiayi

fix bugs

上级 5b4f2830
......@@ -81,8 +81,7 @@ void DataBalanceOpHandle::RunImpl() {
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
int data_num = in_var_handles.size() / places_.size();
WaitInputVarGenerated();
std::vector<std::vector<LoDTensor *>> lod_tensors;
std::vector<std::vector<LoDTensor *>> lod_tensors(data_num);
std::vector<int> device_sizes;
for (int i = 0; i < static_cast<int>(in_var_handles.size()); ++i) {
PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_,
......@@ -105,7 +104,6 @@ void DataBalanceOpHandle::RunImpl() {
}
}
const auto &balance_plan = GetBalancePlan(device_sizes);
for (const auto &trans : balance_plan) {
for (int data_idx = 0; data_idx < data_num; ++data_idx) {
LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]];
......
......@@ -41,8 +41,8 @@ struct DataBalanceOpHandle : public OpHandleBase {
std::vector<std::array<int, 3>> GetBalancePlan(
const std::vector<int> &batch_size_per_device);
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
const std::vector<Scope *> local_scopes_;
const std::vector<platform::Place> places_;
};
} // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册