diff --git a/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc b/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc index 3fc3acce47bdfb5d01cf8a5e5904f052d3e67f29..001da5686feba6d4613648b2b7da50bd7b563c83 100644 --- a/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc +++ b/paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc @@ -206,24 +206,23 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { details::GroupParamsAndGrads *group_params_grads) const { SetGroupAccordingToLayers(var_nodes, params_grads, group_params_grads); SetGroupAccordingToMemorySize(var_nodes, group_params_grads); + ReGroupByDtype(var_nodes, params_grads, group_params_grads); } void SetGroupAccordingToLayers( const std::unordered_map &var_nodes, const details::ParamsAndGrads ¶ms_grads, details::GroupParamsAndGrads *group_params_grads) const { - using var_dtype = std::pair; - std::map var_idx; + std::map var_idx; for (size_t i = 0; i < params_grads.size(); ++i) { auto pos = params_grads[i].first.find_first_of("."); - auto dtype = GetDtypeOfVar(var_nodes, params_grads[i].second); - var_dtype var_key; + std::string var_key; if (pos == std::string::npos) { - var_key = std::make_pair(params_grads[i].first, dtype); + var_key = params_grads[i].first; } else { - var_key = std::make_pair(params_grads[i].first.substr(0, pos), dtype); + var_key = params_grads[i].first.substr(0, pos); } size_t idx = 0; @@ -289,9 +288,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { local_group_params_grads.emplace_back(); auto &group_p_g = local_group_params_grads.back(); - auto &grad_name = group_params_grads->at(j).front().second; - auto var_type = GetDtypeOfVar(var_nodes, grad_name); - size_t local_group_memory_size = 0; while (j < group_params_grads->size()) { std::for_each( @@ -330,12 +326,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { group_memory_size) { break; } - - auto next_var_type = - GetDtypeOfVar(var_nodes, group_params_grads->at(j).front().second); - if (next_var_type != var_type) { - break; - } } } @@ -348,6 +338,55 @@ class AllocContinuousSpaceForGradPass : public ir::Pass { } } + void ReGroupByDtype( + const std::unordered_map &var_nodes, + const details::ParamsAndGrads ¶ms_grads, + details::GroupParamsAndGrads *group_params_grads) const { + if (IsUnifiedDtype(params_grads, var_nodes)) { + VLOG(1) << "needn't regroup fusion params_grads"; + return; + } + + details::GroupParamsAndGrads new_group_params_grads; + + for (auto &group_p_g : *group_params_grads) { + std::map type_idx; + details::GroupParamsAndGrads local_group_params_grads; + + for (auto &p_g : group_p_g) { + auto dtype = GetDtypeOfVar(var_nodes, p_g.second); + + size_t idx = 0; + auto var_idx_iter = type_idx.find(dtype); + if (var_idx_iter != type_idx.end()) { + idx = var_idx_iter->second; + } else { + local_group_params_grads.emplace_back(); + idx = local_group_params_grads.size() - 1; + type_idx[dtype] = idx; + } + + auto &local = local_group_params_grads.at(idx); + local.emplace_back(p_g); + } + + VLOG(10) << "local_group_params_grads size:" + << local_group_params_grads.size(); + new_group_params_grads.insert(new_group_params_grads.end(), + local_group_params_grads.begin(), + local_group_params_grads.end()); + } + + std::swap(*group_params_grads, new_group_params_grads); + + if (VLOG_IS_ON(10)) { + VLOG(10) << string::Sprintf("ReGroupByDtype(memory_size: %f MB, %u):", + GetFuseParameterMemorySize(), + GetFuseParameterGroupsSize()); + PrintGroupInfo(var_nodes, group_params_grads); + } + } + proto::VarType::Type GetDtypeOfVar( const std::unordered_map &var_nodes, const std::string &name) const {