Unverified commit 160ddc98, authored by gongweibao and committed by GitHub

Regroup fusion by data type. (#18496)

Parent 6f6ecbec
@@ -206,24 +206,23 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
       details::GroupParamsAndGrads *group_params_grads) const {
     SetGroupAccordingToLayers(var_nodes, params_grads, group_params_grads);
     SetGroupAccordingToMemorySize(var_nodes, group_params_grads);
+    ReGroupByDtype(var_nodes, params_grads, group_params_grads);
   }
 
   void SetGroupAccordingToLayers(
       const std::unordered_map<std::string, ir::Node *> &var_nodes,
       const details::ParamsAndGrads &params_grads,
       details::GroupParamsAndGrads *group_params_grads) const {
-    using var_dtype = std::pair<std::string, proto::VarType::Type>;
-    std::map<var_dtype, size_t> var_idx;
+    std::map<std::string, size_t> var_idx;
     for (size_t i = 0; i < params_grads.size(); ++i) {
       auto pos = params_grads[i].first.find_first_of(".");
-      auto dtype = GetDtypeOfVar(var_nodes, params_grads[i].second);
-      var_dtype var_key;
+      std::string var_key;
       if (pos == std::string::npos) {
-        var_key = std::make_pair(params_grads[i].first, dtype);
+        var_key = params_grads[i].first;
       } else {
-        var_key = std::make_pair(params_grads[i].first.substr(0, pos), dtype);
+        var_key = params_grads[i].first.substr(0, pos);
       }
       size_t idx = 0;
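Before this commit, SetGroupAccordingToLayers keyed its groups on a (layer-name, dtype) pair; the change above drops dtype from the key because per-dtype splitting now happens later in the new ReGroupByDtype step. A minimal standalone sketch of the simplified keying, with illustrative names in place of the pass's real types:

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Parameters are keyed by the prefix before the first '.', i.e. the
  // layer name; dtype no longer participates in this first grouping step.
  std::vector<std::string> params = {"fc_0.w_0", "fc_0.b_0", "fc_1.w_0"};
  std::map<std::string, size_t> var_idx;  // layer key -> group index
  std::vector<std::vector<std::string>> groups;
  for (auto &name : params) {
    auto pos = name.find_first_of(".");
    std::string var_key =
        (pos == std::string::npos) ? name : name.substr(0, pos);
    size_t idx = 0;
    auto it = var_idx.find(var_key);
    if (it != var_idx.end()) {
      idx = it->second;
    } else {
      groups.emplace_back();
      idx = groups.size() - 1;
      var_idx[var_key] = idx;
    }
    groups.at(idx).push_back(name);
  }
  for (auto &g : groups) {
    for (auto &n : g) std::cout << n << ' ';
    std::cout << '\n';  // prints "fc_0.w_0 fc_0.b_0" then "fc_1.w_0"
  }
  return 0;
}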
@@ -289,9 +288,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
       local_group_params_grads.emplace_back();
       auto &group_p_g = local_group_params_grads.back();
 
-      auto &grad_name = group_params_grads->at(j).front().second;
-      auto var_type = GetDtypeOfVar(var_nodes, grad_name);
-
       size_t local_group_memory_size = 0;
       while (j < group_params_grads->size()) {
         std::for_each(
@@ -330,12 +326,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
             group_memory_size) {
           break;
         }
-
-        auto next_var_type =
-            GetDtypeOfVar(var_nodes, group_params_grads->at(j).front().second);
-        if (next_var_type != var_type) {
-          break;
-        }
       }
     }
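The deletions above take the dtype boundary check out of SetGroupAccordingToMemorySize: a size-merged group may now mix dtypes, since ReGroupByDtype splits them afterwards. A hedged sketch of the resulting merge loop, where MergeBySize and the raw size list are illustrative stand-ins for the pass's real bookkeeping, not its API:

#include <cstddef>
#include <iostream>
#include <vector>

// Merge consecutive items until a byte budget is reached; no dtype check.
std::vector<std::vector<size_t>> MergeBySize(const std::vector<size_t> &sizes,
                                             size_t group_memory_size) {
  std::vector<std::vector<size_t>> merged;
  size_t j = 0;
  while (j < sizes.size()) {
    merged.emplace_back();
    size_t local_size = 0;
    while (j < sizes.size()) {
      local_size += sizes[j];
      merged.back().push_back(sizes[j]);
      ++j;
      if (local_size >= group_memory_size) break;  // budget hit: new group
    }
  }
  return merged;
}

int main() {
  for (auto &g : MergeBySize({4, 4, 8, 2, 2}, 8))
    std::cout << g.size() << ' ';  // prints "2 1 2": {4,4}, {8}, {2,2}
  return 0;
}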
@@ -348,6 +338,55 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
       }
     }
   }
 
+  void ReGroupByDtype(
+      const std::unordered_map<std::string, ir::Node *> &var_nodes,
+      const details::ParamsAndGrads &params_grads,
+      details::GroupParamsAndGrads *group_params_grads) const {
+    if (IsUnifiedDtype(params_grads, var_nodes)) {
+      VLOG(1) << "needn't regroup fusion params_grads";
+      return;
+    }
+
+    details::GroupParamsAndGrads new_group_params_grads;
+
+    for (auto &group_p_g : *group_params_grads) {
+      std::map<proto::VarType::Type, size_t> type_idx;
+      details::GroupParamsAndGrads local_group_params_grads;
+
+      for (auto &p_g : group_p_g) {
+        auto dtype = GetDtypeOfVar(var_nodes, p_g.second);
+
+        size_t idx = 0;
+        auto var_idx_iter = type_idx.find(dtype);
+        if (var_idx_iter != type_idx.end()) {
+          idx = var_idx_iter->second;
+        } else {
+          local_group_params_grads.emplace_back();
+          idx = local_group_params_grads.size() - 1;
+          type_idx[dtype] = idx;
+        }
+
+        auto &local = local_group_params_grads.at(idx);
+        local.emplace_back(p_g);
+      }
+
+      VLOG(10) << "local_group_params_grads size:"
+               << local_group_params_grads.size();
+
+      new_group_params_grads.insert(new_group_params_grads.end(),
+                                    local_group_params_grads.begin(),
+                                    local_group_params_grads.end());
+    }
+
+    std::swap(*group_params_grads, new_group_params_grads);
+
+    if (VLOG_IS_ON(10)) {
+      VLOG(10) << string::Sprintf("ReGroupByDtype(memory_size: %f MB, %u):",
+                                  GetFuseParameterMemorySize(),
+                                  GetFuseParameterGroupsSize());
+      PrintGroupInfo(var_nodes, group_params_grads);
+    }
+  }
+
   proto::VarType::Type GetDtypeOfVar(
       const std::unordered_map<std::string, Node *> &var_nodes,
       const std::string &name) const {
...
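The new ReGroupByDtype walks every fused group and stably partitions it into per-dtype subgroups, preserving the original order within each dtype. A self-contained sketch of that partition logic; the Dtype enum and the pair-of-strings representation are assumptions standing in for proto::VarType::Type and details::ParamsAndGrads, not Paddle's real definitions:

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

enum class Dtype { FP32, FP16 };

using ParamGrad = std::pair<std::string, Dtype>;  // (grad name, its dtype)
using Group = std::vector<ParamGrad>;

// Split each group into per-dtype subgroups, first-seen dtype order kept.
std::vector<Group> ReGroupByDtype(const std::vector<Group> &groups) {
  std::vector<Group> result;
  for (const auto &group : groups) {
    std::map<Dtype, size_t> type_idx;  // dtype -> index into local subgroups
    std::vector<Group> local;
    for (const auto &pg : group) {
      size_t idx = 0;
      auto it = type_idx.find(pg.second);
      if (it != type_idx.end()) {
        idx = it->second;
      } else {
        local.emplace_back();
        idx = local.size() - 1;
        type_idx[pg.second] = idx;
      }
      local.at(idx).push_back(pg);
    }
    result.insert(result.end(), local.begin(), local.end());
  }
  return result;
}

int main() {
  std::vector<Group> groups = {{{"w0@GRAD", Dtype::FP32},
                                {"w1@GRAD", Dtype::FP16},
                                {"w2@GRAD", Dtype::FP32}}};
  for (const auto &g : ReGroupByDtype(groups)) {
    for (const auto &pg : g) std::cout << pg.first << ' ';
    std::cout << '\n';  // prints "w0@GRAD w2@GRAD" then "w1@GRAD"
  }
  return 0;
}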