未验证 提交 4700a08e 编写于 作者: J Jiabin Yang 提交者: GitHub

Support backward prune for eager intermediate (#43111)

* support is empty

* fix error

* fix code error

* change to fake empty

* using fake empty first

* using fake empty first

* Support backward prune in fluid
上级 67497119
...@@ -2189,7 +2189,6 @@ static std::string GenerateSingleOpBase( ...@@ -2189,7 +2189,6 @@ static std::string GenerateSingleOpBase(
} }
VLOG(6) << "Generated Ins Map"; VLOG(6) << "Generated Ins Map";
// [Generation] Get Outs Map // [Generation] Get Outs Map
std::string outs_contents_str = ""; std::string outs_contents_str = "";
for (auto iter : grad_outs) { for (auto iter : grad_outs) {
...@@ -2238,9 +2237,12 @@ static std::string GenerateSingleOpBase( ...@@ -2238,9 +2237,12 @@ static std::string GenerateSingleOpBase(
size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name); size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name);
const char* GRAD_OUTS_CONTENT_TEMPLATE = const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::TrySyncToVars(hooked_grads[%d]) },"; " if((!out_metas[%d].empty()) && "
"(!(out_metas[%d][0].IsStopGradient()))){ \n %s.insert({ \"%s\", "
"egr::EagerUtils::TrySyncToVars(hooked_grads[%d])});} \n ";
outs_contents_str += paddle::string::Sprintf( outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grads_position); GRAD_OUTS_CONTENT_TEMPLATE, grads_position, grads_position,
outs_name, grad_output_name, grads_position);
} else { } else {
if (dispensable_input_name_set.count(fwd_name) && if (dispensable_input_name_set.count(fwd_name) &&
...@@ -2251,18 +2253,20 @@ static std::string GenerateSingleOpBase( ...@@ -2251,18 +2253,20 @@ static std::string GenerateSingleOpBase(
if (duplicable_input_name_set.count(fwd_name) && if (duplicable_input_name_set.count(fwd_name) &&
!is_op_base_per_duplicable_input) { !is_op_base_per_duplicable_input) {
const char* GRAD_OUTS_CONTENT_TEMPLATE = const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::CreateVars( " " if(!out_metas[%d].empty()){ %s.insert({ \"%s\", "
"this->OutputMeta()[%d].size() ) },"; "egr::EagerUtils::CreateVars(out_metas[%d].size())});} \n ";
outs_contents_str += paddle::string::Sprintf( outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position); GRAD_OUTS_CONTENT_TEMPLATE, fwd_input_position, outs_name,
grad_output_name, fwd_input_position);
} else { } else {
const char* GRAD_OUTS_CONTENT_TEMPLATE = const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", " " if((!out_metas[%d].empty()) && "
"(!(out_metas[%d][0].IsStopGradient()))){ %s.insert({ \"%s\", "
"{std::make_shared<egr::EagerVariable>(egr::Controller::Instance(" "{std::make_shared<egr::EagerVariable>(egr::Controller::Instance("
")." ").GenerateUniqueName())}});} \n ";
"GenerateUniqueName())}},";
outs_contents_str += paddle::string::Sprintf( outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name); GRAD_OUTS_CONTENT_TEMPLATE, fwd_input_position,
fwd_input_position, outs_name, grad_output_name);
} }
} }
} else { } else {
...@@ -2272,16 +2276,15 @@ static std::string GenerateSingleOpBase( ...@@ -2272,16 +2276,15 @@ static std::string GenerateSingleOpBase(
grad_output_name)); grad_output_name));
} }
} }
if (outs_contents_str.size() > 0)
outs_contents_str.pop_back(); // // Remove trailing ","
const char* BWD_OUTS_MAP_TEMPLATE = const char* BWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, " " std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerVariable>>> %s = { " "std::vector<std::shared_ptr<egr::EagerVariable>>> %s;\n";
"%s };\n"; std::string outs_map_str =
std::string outs_map_str = paddle::string::Sprintf( paddle::string::Sprintf(BWD_OUTS_MAP_TEMPLATE, outs_name);
BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str);
generated_grad_function_body += outs_map_str; generated_grad_function_body += outs_map_str;
generated_grad_function_body += outs_contents_str;
generated_grad_function_body += "\n"; generated_grad_function_body += "\n";
for (auto iter : grad_outs) { for (auto iter : grad_outs) {
const std::string& grad_output_name = iter.first; const std::string& grad_output_name = iter.first;
...@@ -2296,18 +2299,23 @@ static std::string GenerateSingleOpBase( ...@@ -2296,18 +2299,23 @@ static std::string GenerateSingleOpBase(
!is_op_base_per_duplicable_input) { !is_op_base_per_duplicable_input) {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE = const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE =
" if(%s.size() > 0) %s[\"%s\"] = egr::EagerUtils::CreateVars( " " if((%s.size() > 0) && (!out_metas[%d].empty()) && "
"this->OutputMeta()[%d].size() );\n"; "(!out_metas[%d][0].IsStopGradient())) %s[\"%s\"] = "
"egr::EagerUtils::CreateVars( "
"out_metas[%d].size() );\n";
generated_grad_function_body += paddle::string::Sprintf( generated_grad_function_body += paddle::string::Sprintf(
DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name, outs_name, DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name, outs_name,
grad_output_name, fwd_input_position); grad_output_name, fwd_input_position);
} else { } else {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE = const char* DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE =
" if(%s.defined()) %s[\"%s\"] = " " if(%s.defined() && (!out_metas[%d].empty()) && "
"(!out_metas[%d][0].IsStopGradient())) %s[\"%s\"] = "
"{std::make_shared<egr::EagerVariable>(egr::Controller::" "{std::make_shared<egr::EagerVariable>(egr::Controller::"
"Instance().GenerateUniqueName())};\n"; "Instance().GenerateUniqueName())};\n";
generated_grad_function_body += paddle::string::Sprintf( generated_grad_function_body += paddle::string::Sprintf(
DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name, outs_name, DISPENSABLE_GRAD_OUTS_FWD_CONTENT_TEMPLATE, fwd_name,
fwd_input_position, fwd_input_position, outs_name,
grad_output_name); grad_output_name);
} }
} }
...@@ -2387,16 +2395,20 @@ static std::string GenerateSingleOpBase( ...@@ -2387,16 +2395,20 @@ static std::string GenerateSingleOpBase(
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (!is_op_base_per_duplicable_input) { if (!is_op_base_per_duplicable_input) {
const char* BWD_OUTPUT_TEMPLATE = const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n"; " if (%s.find(\"%s\") != %s.end()) { outputs[%d] = "
"egr::EagerUtils::GetOutputs(%s[\"%s\"]); }\n";
outputs_str += paddle::string::Sprintf( outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name); BWD_OUTPUT_TEMPLATE, outs_name, grad_out_name, outs_name,
fwd_input_position, outs_name, grad_out_name);
} else { } else {
const char* BWD_OUTPUT_TEMPLATE = const char* BWD_OUTPUT_TEMPLATE =
" " " "
"if (%s.find(\"%s\") != %s.end()) { "
"outputs[0].emplace_back(egr::EagerUtils::GetOutputs(%s[\"%s\"])[0]" "outputs[0].emplace_back(egr::EagerUtils::GetOutputs(%s[\"%s\"])[0]"
");\n"; "); }\n";
outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, outs_name, outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, outs_name,
grad_out_name); grad_out_name, outs_name,
outs_name, grad_out_name);
} }
num_appended_outputs++; num_appended_outputs++;
} else { } else {
...@@ -2415,9 +2427,11 @@ static std::string GenerateSingleOpBase( ...@@ -2415,9 +2427,11 @@ static std::string GenerateSingleOpBase(
if (fwd_outputs_name_pos_map.count(fwd_name)) { if (fwd_outputs_name_pos_map.count(fwd_name)) {
const char* BWD_OUTPUT_TEMPLATE = const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n"; " if (%s.find(\"%s\") != %s.end()) { outputs[%d] = "
"egr::EagerUtils::GetOutputs(%s[\"%s\"]); }\n";
outputs_str += paddle::string::Sprintf( outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, num_appended_outputs, outs_name, grad_out_name); BWD_OUTPUT_TEMPLATE, outs_name, grad_out_name, outs_name,
num_appended_outputs, outs_name, grad_out_name);
num_appended_outputs++; num_appended_outputs++;
} }
} }
...@@ -2550,6 +2564,7 @@ static std::string GenerateGradNodeCCContents( ...@@ -2550,6 +2564,7 @@ static std::string GenerateGradNodeCCContents(
" paddle::small_vector<std::vector<paddle::experimental::Tensor>, " " paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize> hooked_grads = " "egr::kSlotSmallVectorSize> hooked_grads = "
"GradNode%s::ApplyGradientHooks(grads);\n" "GradNode%s::ApplyGradientHooks(grads);\n"
" const auto& out_metas = OutputMeta();\n"
" paddle::small_vector<std::vector<paddle::experimental::Tensor>, " " paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize> outputs(%d);\n" "egr::kSlotSmallVectorSize> outputs(%d);\n"
" %s\n" " %s\n"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册