未验证 提交 4700a08e 编写于 作者: J Jiabin Yang 提交者: GitHub

Support backward prune for eager intermediate (#43111)

* support is empty

* fix error

* fix code error

* change to fake empty

* using fake empty first

* using fake empty first

* Support backward prune in fluid
上级 67497119
......@@ -2189,7 +2189,6 @@ static std::string GenerateSingleOpBase(
}
VLOG(6) << "Generated Ins Map";
// [Generation] Get Outs Map
std::string outs_contents_str = "";
for (auto iter : grad_outs) {
......@@ -2238,9 +2237,12 @@ static std::string GenerateSingleOpBase(
size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name);
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::TrySyncToVars(hooked_grads[%d]) },";
" if((!out_metas[%d].empty()) && "
"(!(out_metas[%d][0].IsStopGradient()))){ \n %s.insert({ \"%s\", "
"egr::EagerUtils::TrySyncToVars(hooked_grads[%d])});} \n ";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grads_position);
GRAD_OUTS_CONTENT_TEMPLATE, grads_position, grads_position,
outs_name, grad_output_name, grads_position);
} else {
if (dispensable_input_name_set.count(fwd_name) &&
......@@ -2251,18 +2253,20 @@ static std::string GenerateSingleOpBase(
if (duplicable_input_name_set.count(fwd_name) &&
!is_op_base_per_duplicable_input) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::CreateVars( "
"this->OutputMeta()[%d].size() ) },";
" if(!out_metas[%d].empty()){ %s.insert({ \"%s\", "
"egr::EagerUtils::CreateVars(out_metas[%d].size())});} \n ";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
GRAD_OUTS_CONTENT_TEMPLATE, fwd_input_position, outs_name,
grad_output_name, fwd_input_position);
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
" if((!out_metas[%d].empty()) && "
"(!(out_metas[%d][0].IsStopGradient()))){ %s.insert({ \"%s\", "
"{std::make_shared<egr::EagerVariable>(egr::Controller::Instance("
")."
"GenerateUniqueName())}},";
").GenerateUniqueName())}});} \n ";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name);
GRAD_OUTS_CONTENT_TEMPLATE, fwd_input_position,
fwd_input_position, outs_name, grad_output_name);
}
}
} else {
......@@ -2272,16 +2276,15 @@ static std::string GenerateSingleOpBase(
grad_output_name));
}
}
if (outs_contents_str.size() > 0)
outs_contents_str.pop_back(); // // Remove trailing ","
const char* BWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerVariable>>> %s = { "
"%s };\n";
std::string outs_map_str = paddle::string::Sprintf(
BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str);
"std::vector<std::shared_ptr<egr::EagerVariable>>> %s;\n";
std::string outs_map_str =
paddle::string::Sprintf(BWD_OUTS_MAP_TEMPLATE, outs_name);
generated_grad_function_body += outs_map_str;
generated_grad_function_body += outs_contents_str;
generated_grad_function_body += "\n";
for (auto iter : grad_outs) {
const std::string& grad_output_name = iter.first;
......@@ -2296,18 +2299,23 @@ static std::string GenerateSingleOpBase(
!is_op_base_per_duplicable_input) {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE =
" if(%s.size() > 0) %s[\"%s\"] = egr::EagerUtils::CreateVars( "
"this->OutputMeta()[%d].size() );\n";
" if((%s.size() > 0) && (!out_metas[%d].empty()) && "
"(!out_metas[%d][0].IsStopGradient())) %s[\"%s\"] = "
"egr::EagerUtils::CreateVars( "
"out_metas[%d].size() );\n";
generated_grad_function_body += paddle::string::Sprintf(
DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name, outs_name,
grad_output_name, fwd_input_position);
} else {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE =
" if(%s.defined()) %s[\"%s\"] = "
" if(%s.defined() && (!out_metas[%d].empty()) && "
"(!out_metas[%d][0].IsStopGradient())) %s[\"%s\"] = "
"{std::make_shared<egr::EagerVariable>(egr::Controller::"
"Instance().GenerateUniqueName())};\n";
generated_grad_function_body += paddle::string::Sprintf(
DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name, outs_name,
DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name,
fwd_input_position, fwd_input_position, outs_name,
grad_output_name);
}
}
......@@ -2387,16 +2395,20 @@ static std::string GenerateSingleOpBase(
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (!is_op_base_per_duplicable_input) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
" if (%s.find(\"%s\") != %s.end()) { outputs[%d] = "
"egr::EagerUtils::GetOutputs(%s[\"%s\"]); }\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name);
BWD_OUTPUT_TEMPLATE, outs_name, grad_out_name, outs_name,
fwd_input_position, outs_name, grad_out_name);
} else {
const char* BWD_OUTPUT_TEMPLATE =
" "
"if (%s.find(\"%s\") != %s.end()) { "
"outputs[0].emplace_back(egr::EagerUtils::GetOutputs(%s[\"%s\"])[0]"
");\n";
"); }\n";
outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, outs_name,
grad_out_name);
grad_out_name, outs_name,
outs_name, grad_out_name);
}
num_appended_outputs++;
} else {
......@@ -2415,9 +2427,11 @@ static std::string GenerateSingleOpBase(
if (fwd_outputs_name_pos_map.count(fwd_name)) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
" if (%s.find(\"%s\") != %s.end()) { outputs[%d] = "
"egr::EagerUtils::GetOutputs(%s[\"%s\"]); }\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, num_appended_outputs, outs_name, grad_out_name);
BWD_OUTPUT_TEMPLATE, outs_name, grad_out_name, outs_name,
num_appended_outputs, outs_name, grad_out_name);
num_appended_outputs++;
}
}
......@@ -2550,6 +2564,7 @@ static std::string GenerateGradNodeCCContents(
" paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize> hooked_grads = "
"GradNode%s::ApplyGradientHooks(grads);\n"
" const auto& out_metas = OutputMeta();\n"
" paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize> outputs(%d);\n"
" %s\n"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册