提交 13e8b5bf 编写于 作者: Q Qiao Longfei

clear gradient before merge

上级 50601501
......@@ -47,6 +47,8 @@ static inline void MergeVars(const std::string &var_name,
}
} else if (var0->IsType<framework::SelectedRows>()) {
auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->mutable_rows()->clear();
out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
std::vector<const paddle::framework::SelectedRows *> inputs;
inputs.reserve(vars.size());
for (auto &var : vars) {
......@@ -71,6 +73,7 @@ void Communicator::SendThread() {
VLOG(3) << "merge var " << var_name << " and send";
auto &var_queue = iter.second;
std::vector<std::shared_ptr<Variable>> vars;
// TODO(qiao): need to be configurable
const size_t max_merge_var_num = 20;
size_t merged_var_num = 0;
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册