未验证 提交 9bb39d48 编写于 作者: J Jiabin Yang 提交者: GitHub

support prune (#43250)

上级 71a63f0a
......@@ -1206,22 +1206,37 @@ static std::string GenerateGradNodeCreationContent(
if (!input.duplicable()) {
compute_require_grad_args += ", " + input_autograd_name;
size_t input_position = fwd_inputs_name_pos_map.at(input_name);
bool found_target_name = false;
for (const auto& iter : op_base_infos) {
const auto& grad_outs_slot_map = iter.GetGradOutsSlotnameMap();
for (auto iter : grad_outs_slot_map) {
if ((!found_target_name) && (input_name == iter.second)) {
const char* SET_GRAD_OUT_META_TEMPLATE =
" grad_node->SetGradOutMeta(%s, %d);\n";
grad_node_creation_str +=
paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE,
LegalizeVarName(input_name), input_position);
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_OUT_META_TEMPLATE, LegalizeVarName(input_name),
input_position);
found_target_name = true;
}
}
}
} else {
compute_require_grad_args += ", &" + input_autograd_name;
size_t input_position = fwd_inputs_name_pos_map.at(input_name);
bool found_target_name = false;
for (const auto& iter : op_base_infos) {
const auto& grad_outs_slot_map = iter.GetGradOutsSlotnameMap();
for (auto iter : grad_outs_slot_map) {
if ((!found_target_name) && (input_name == iter.second)) {
const char* SET_GRAD_OUT_META_TEMPLATE =
" grad_node->SetGradOutMeta(%s, %d);\n";
grad_node_creation_str +=
paddle::string::Sprintf(SET_GRAD_OUT_META_TEMPLATE,
LegalizeVarName(input_name), input_position);
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_OUT_META_TEMPLATE, LegalizeVarName(input_name),
input_position);
found_target_name = true;
}
}
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册