提交 13e8b5bf 编写于 作者: Q Qiao Longfei

clear gradient before merge

上级 50601501
...@@ -47,6 +47,8 @@ static inline void MergeVars(const std::string &var_name, ...@@ -47,6 +47,8 @@ static inline void MergeVars(const std::string &var_name,
} }
} else if (var0->IsType<framework::SelectedRows>()) { } else if (var0->IsType<framework::SelectedRows>()) {
auto *out_slr = out_var->GetMutable<framework::SelectedRows>(); auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->mutable_rows()->clear();
out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
std::vector<const paddle::framework::SelectedRows *> inputs; std::vector<const paddle::framework::SelectedRows *> inputs;
inputs.reserve(vars.size()); inputs.reserve(vars.size());
for (auto &var : vars) { for (auto &var : vars) {
...@@ -71,6 +73,7 @@ void Communicator::SendThread() { ...@@ -71,6 +73,7 @@ void Communicator::SendThread() {
VLOG(3) << "merge var " << var_name << " and send"; VLOG(3) << "merge var " << var_name << " and send";
auto &var_queue = iter.second; auto &var_queue = iter.second;
std::vector<std::shared_ptr<Variable>> vars; std::vector<std::shared_ptr<Variable>> vars;
// TODO(qiao): need to be configurable
const size_t max_merge_var_num = 20; const size_t max_merge_var_num = 20;
size_t merged_var_num = 0; size_t merged_var_num = 0;
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) { while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册