Unverified commit ed2886de, authored by pangyoki, committed by GitHub

support backward inplace in eager fluid dygraph mode (#43054)

* support backward inplace in eager fluid mode

* fix

* fix

* optimize format

* little change
Parent 3d56d419
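In short: in the eager "fluid" dygraph path, the auto-generated grad functions can now reuse a grad input's buffer for the grad output whenever that buffer is provably not shared elsewhere. Below is a minimal standalone sketch of the reuse condition, with hypothetical `Tensor`/`Impl` stand-ins rather than Paddle's actual classes:

```cpp
#include <memory>

// Hypothetical stand-ins for a tensor backed by a shared buffer.
struct Impl { /* allocation */ };
struct Tensor {
  std::shared_ptr<Impl> impl;
};

// A grad input may be written in place as the grad output when nothing else
// holds its buffer (use_count == 1), or when the only other holder is a saved
// alias of the very same buffer (use_count == 2 and the raw pointers match).
// This mirrors the condition the generated code checks in this commit.
bool CanBeInplaced(const Tensor& grad_in, const Tensor& saved_alias) {
  if (!grad_in.impl) return false;
  const long uc = grad_in.impl.use_count();
  return uc == 1 ||
         (uc == 2 && grad_in.impl.get() == saved_alias.impl.get());
}
```

The `use_count == 2` escape hatch matters because the grad node itself may legitimately hold one extra reference to the same buffer (the saved intermediate tensor, or the pre-hook grad); sharing a buffer with yourself is safe.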
@@ -231,6 +231,15 @@ class GradNodeGenerationInfo {
     return &no_need_buffer_ins_;
   }

+  const std::unordered_map<std::string, std::string>& GetBackwardInplaceMap()
+      const {
+    return backward_inplace_map_;
+  }
+
+  std::unordered_map<std::string, std::string>*
+  GetMutableBackwardInplaceMap() {
+    return &backward_inplace_map_;
+  }

  private:
   std::string op_base_type_;
   std::map<std::string, std::string> grad_outs_slotname_map_;
@@ -244,6 +253,7 @@ class GradNodeGenerationInfo {
       grad_outs_;
   paddle::framework::AttributeMap grad_attrs_;
   std::unordered_set<std::string> no_need_buffer_ins_;
+  std::unordered_map<std::string, std::string> backward_inplace_map_;
 };

 public:
@@ -979,6 +989,12 @@ static bool CollectGradInformationFromOpInfo(
       *(*op_base_infos)[index].GetMutableNoNeedBufferInputs() =
           inferer(g_ins, g_outs, *op_base_grad_attrs);
     }
+
+    auto& infer_backward_inplace = op_base.Info().infer_inplace_;
+    if (infer_backward_inplace) {
+      *(*op_base_infos)[index].GetMutableBackwardInplaceMap() =
+          infer_backward_inplace(true);
+    }
   }

   /* ------ Slot Name Matching ---- */
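For context, `infer_inplace_` is the op's inplace-inference functor; called with `true` it yields a map from a grad-op input slot to the output slot it may alias. A hedged illustration of the shape of that result (the slot names here are hypothetical and depend on the op definition):

```cpp
#include <string>
#include <unordered_map>

// Hypothetical result of infer_backward_inplace(true) for some grad op:
// the grad input "Out@GRAD" may share its buffer with the grad output
// "X@GRAD".
std::unordered_map<std::string, std::string> backward_inplace_map = {
    {"Out@GRAD", "X@GRAD"},
};
```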
@@ -1005,7 +1021,7 @@ static std::string GenerateGradNodeCreationContent(
     const ForwardGenerationInfo& fwd_info,
     const GradNodeGenerationInfo& bwd_info,
     const std::string& trace_op_body_str,
-    std::map<std::string, std::string> inplace_map = {}) {
+    std::map<std::string, std::string> forward_inplace_map = {}) {
   VLOG(6) << "Generating GradNode Creation codes";

   const std::string& op_type = fwd_info.GetOpType();
@@ -1045,8 +1061,10 @@ static std::string GenerateGradNodeCreationContent(
     } else {
       // In inplace op, the case where output is duplicable is not considered.
       // Replace output directly with input in inplace op.
-      if (!inplace_map.empty() && inplace_map.count(output_name)) {
-        auto inplace_input_name = LegalizeVarName(inplace_map[output_name]);
+      if (!forward_inplace_map.empty() &&
+          forward_inplace_map.count(output_name)) {
+        auto inplace_input_name =
+            LegalizeVarName(forward_inplace_map[output_name]);
         const std::string& inplace_input_autograd_name =
             "p_autograd_" + inplace_input_name;
         const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
@@ -1103,12 +1121,12 @@ static std::string GenerateGradNodeCreationContent(
   // check inplace input to avoid inplace operations on leaf nodes with
   // stop_gradient=False.
   std::string check_inplace_str = "";
-  if (!inplace_map.empty()) {
+  if (!forward_inplace_map.empty()) {
     const char* CHECKING_INPLACE_TEMPLATE =
         "  // Check Inplace\n"
        "  egr::EagerUtils::CheckInplace(%s, p_autograd_%s, "
        "require_any_grad);\n";
-    for (auto& inplace_pair : inplace_map) {
+    for (auto& inplace_pair : forward_inplace_map) {
       std::string inplace_name = LegalizeVarName(inplace_pair.second);
       check_inplace_str += paddle::string::Sprintf(CHECKING_INPLACE_TEMPLATE,
                                                    inplace_name, inplace_name);
@@ -1161,8 +1179,9 @@ static std::string GenerateGradNodeCreationContent(
       const char* SET_TENSOR_WRAPPER_TEMPLATE =
           "  grad_node->SetTensorWrapper%s(%s);\n";
       // Replace output directly with input in inplace op.
-      if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) {
-        auto inplace_input_name = inplace_map[tensor_wrapper_name];
+      if (!forward_inplace_map.empty() &&
+          forward_inplace_map.count(tensor_wrapper_name)) {
+        auto inplace_input_name = forward_inplace_map[tensor_wrapper_name];
         grad_node_creation_str += paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
             LegalizeVarName(inplace_input_name));
@@ -1213,8 +1232,9 @@ static std::string GenerateGradNodeCreationContent(
   for (const proto::OpProto::Var& output : out_vars) {
     const std::string& output_name = output.name();
     // Replace output directly with input in inplace op.
-    if (!inplace_map.empty() && inplace_map.count(output_name)) {
-      auto inplace_input_name = inplace_map[output_name];
+    if (!forward_inplace_map.empty() &&
+        forward_inplace_map.count(output_name)) {
+      auto inplace_input_name = forward_inplace_map[output_name];
       const std::string& inplace_input_autograd_name =
           "p_autograd_" + LegalizeVarName(inplace_input_name);
       size_t output_position = fwd_outputs_name_pos_map.at(output_name);
@@ -1345,7 +1365,7 @@ static std::string GenerateGradNodeCreationContent(
 static std::pair<std::string, std::string> GenerateForwardFunctionContents(
     const ForwardGenerationInfo& fwd_info,
     const GradNodeGenerationInfo& bwd_info,
-    std::map<std::string, std::string> inplace_map = {}) {
+    std::map<std::string, std::string> forward_inplace_map = {}) {
   /* --- Process Forward Info ---*/
   const std::string& op_type = fwd_info.GetOpType();
   const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
@@ -1434,8 +1454,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
       // inplace tensor can't be const
       const char* FWD_INS_ARG_TEMPLATE;
       bool flag_find_input_name = false;
-      if (!inplace_map.empty()) {
-        for (auto& inplace_pair : inplace_map) {
+      if (!forward_inplace_map.empty()) {
+        for (auto& inplace_pair : forward_inplace_map) {
           if (inplace_pair.second == input_name) {
             flag_find_input_name = true;
             FWD_INS_ARG_TEMPLATE = "paddle::experimental::Tensor& %s";
@@ -1605,15 +1625,16 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
       }
       core_ops_args_info[op_type].push_back(output_name);
-    } else if (!inplace_map.empty() && inplace_map.count(output_name)) {
+    } else if (!forward_inplace_map.empty() &&
+               forward_inplace_map.count(output_name)) {
       // In inplace op, replace the output with the input directly.
       PADDLE_ENFORCE_NE(
-          inplace_map[output_name], "",
+          forward_inplace_map[output_name], "",
           paddle::platform::errors::InvalidArgument(
               "Inplace op %s has no input corresponding to output %s.", op_type,
               output_name));
       const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", ins[\"%s\"] },";
-      auto inplace_input_name = inplace_map[output_name];
+      auto inplace_input_name = forward_inplace_map[output_name];
       outs_contents_str += paddle::string::Sprintf(
           FWD_OUTS_CONTENT_TEMPLATE, output_name, inplace_input_name);
@@ -1651,7 +1672,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   if (inplace_mapping_str.size() > 0)
     inplace_mapping_str.pop_back();  // Remove trailing ","

-  if ((op_type != "cast") && (inplace_map.empty())) {
+  if ((op_type != "cast") && (forward_inplace_map.empty())) {
     VLOG(6) << "Generating Dygraph Forward AMP";
     const char* AMP_LOGIC_CONTEXT =
         "  if (egr::Controller::Instance().GetAMPLevel() != "
@@ -1743,7 +1764,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   VLOG(6) << "Generated Outs Map";

   // [Generation] Apply View Strategy (Tensor)
-  if (inplace_map.empty() && view_op_map.count(op_type)) {
+  if (forward_inplace_map.empty() && view_op_map.count(op_type)) {
     const char* HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT =
         "  if (ins.count(\"%s\") && outs.count(\"%s\")) {\n"
        "    egr::EagerUtils::HandleViewBetweenInputAndOutput(ins[\"%s\"][0], "
@@ -1852,10 +1873,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
                                        output_varname, output_var_args_name);
       }
     } else {
-      if (!inplace_map.empty() && inplace_map.count(output_name)) {
+      if (!forward_inplace_map.empty() &&
+          forward_inplace_map.count(output_name)) {
         // Modify meta info of inplace tensor.
         // Bump inplace version of inplace tensor.
-        auto inplace_input_name = inplace_map[output_name];
+        auto inplace_input_name = forward_inplace_map[output_name];
         const char* FWD_OUT_TENSOR_TEMPLATE =
             "  egr::EagerUtils::GetOutput(outs[\"%s\"][0], &%s);\n"
            "  %s.bump_inplace_version();\n"
@@ -1878,10 +1900,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
         return_types[return_position] = "paddle::experimental::Tensor";
       }

-      if (!inplace_map.empty() && inplace_map.count(output_name)) {
+      if (!forward_inplace_map.empty() &&
+          forward_inplace_map.count(output_name)) {
         // Replace output directly with input in inplace op.
         return_contents[return_position] =
-            LegalizeVarName(inplace_map[output_name]);
+            LegalizeVarName(forward_inplace_map[output_name]);
       } else {
         return_contents[return_position] = output_varname;
       }
@@ -1903,7 +1926,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
     // If GradNode needs to be generated, pass `trace_op_body_str`
     // into `GenerateGradNodeCreationContent`.
     std::string grad_node_creation_body_str = GenerateGradNodeCreationContent(
-        fwd_info, bwd_info, trace_op_body_str, inplace_map);
+        fwd_info, bwd_info, trace_op_body_str, forward_inplace_map);
     generated_function_body += grad_node_creation_body_str;
     generated_function_body += "\n";
@@ -1960,7 +1983,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   // [Generation] Get Full Function
   std::string function_name;
-  if (inplace_map.empty()) {
+  if (forward_inplace_map.empty()) {
     function_name = op_type + "_dygraph_function";
   } else {
     // change function_name for inplace op.
@@ -2013,6 +2036,7 @@ static std::string GenerateSingleOpBase(
         std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
         grad_outs,
     const paddle::framework::AttributeMap& grad_attrs,
+    const std::unordered_map<std::string, std::string>& backward_inplace_map,
     bool is_op_base_per_duplicable_input, size_t* outs_size) {
   std::string generated_grad_function_body = "";
@@ -2029,6 +2053,23 @@ static std::string GenerateSingleOpBase(
   for (const auto& in : in_vars) {
     if (in.duplicable()) duplicable_input_name_set.insert(in.name());
   }
+  const char* CHECK_BACKWARD_INPLACE_TEMPLATE =
+      "  // Check backward inplace info\n"
+      "  bool %s = false;\n"
+      "  %s\n"
+      "  if (%s.initialized()) {\n"
+      "    VLOG(10) << %s.name() << \"(%s) use_count: \" << "
+      "%s.impl().use_count();\n"
+      "    if (%s.impl().use_count() == 1 || (%s.impl().use_count() == 2 && "
+      "%s.impl().get() == %s.impl().get())) {\n"
+      "      %s = true;\n"
+      "    }\n"
+      "  }\n";
+  const std::string& can_be_inplaced_name =
+      "can_be_inplaced" + std::to_string(*outs_size);
+  const std::string& bwd_inplace_input_name =
+      "backward_inplace_tensor" + std::to_string(*outs_size);
+  bool process_backward_inplace = false;
   std::string ins_contents_str = "";
   for (auto iter : grad_ins) {
     const std::string& grad_input_name = iter.first;
@@ -2051,7 +2092,26 @@ static std::string GenerateSingleOpBase(
       ins_contents_str +=
           paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE,
                                   grad_input_name, struct_fwd_input_name);
+      if (!backward_inplace_map.empty() &&
+          backward_inplace_map.count(grad_input_name)) {
+        process_backward_inplace = true;
+        const char* GRAD_INS_FWD_TENSOR_WRAPPER_TEMPLATE =
+            "auto %s = egr::EagerUtils::RecoverTensorWrapper(&this->%s);";
+        std::string tensor_wrapper_str = paddle::string::Sprintf(
+            GRAD_INS_FWD_TENSOR_WRAPPER_TEMPLATE, bwd_inplace_input_name,
+            struct_fwd_input_name);
+        const char* GRAD_INS_FWD_TENSOR_TEMPLATE =
+            "(&this->%s)->get_intermidiate_tensor()";
+        std::string tensor_wrapper_intermidiate_tensor_str =
+            paddle::string::Sprintf(GRAD_INS_FWD_TENSOR_TEMPLATE,
+                                    struct_fwd_input_name);
+        generated_grad_function_body += paddle::string::Sprintf(
+            CHECK_BACKWARD_INPLACE_TEMPLATE, can_be_inplaced_name,
+            tensor_wrapper_str, bwd_inplace_input_name, bwd_inplace_input_name,
+            grad_input_name, bwd_inplace_input_name, bwd_inplace_input_name,
+            bwd_inplace_input_name, bwd_inplace_input_name,
+            tensor_wrapper_intermidiate_tensor_str, can_be_inplaced_name);
+      }
     } else if (grad_ins_grad_slotname_map.count(grad_input_name)) {
       // Fwd Tensor's Grad
       size_t fwd_output_position = fwd_outputs_name_pos_map.at(
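Instantiated, `CHECK_BACKWARD_INPLACE_TEMPLATE` expands to something like the following in the generated grad function. This is a sketch, assuming a forward input `X` saved in a tensor wrapper member `X_` and `*outs_size == 0`; the slot and member names are illustrative, not taken from a real op:

```cpp
// Check backward inplace info
bool can_be_inplaced0 = false;
auto backward_inplace_tensor0 =
    egr::EagerUtils::RecoverTensorWrapper(&this->X_);
if (backward_inplace_tensor0.initialized()) {
  VLOG(10) << backward_inplace_tensor0.name() << "(X) use_count: "
           << backward_inplace_tensor0.impl().use_count();
  // Safe to reuse the buffer when only this tensor holds it, or when the
  // sole extra holder is the tensor wrapper's own intermediate tensor.
  if (backward_inplace_tensor0.impl().use_count() == 1 ||
      (backward_inplace_tensor0.impl().use_count() == 2 &&
       backward_inplace_tensor0.impl().get() ==
           (&this->X_)->get_intermidiate_tensor().impl().get())) {
    can_be_inplaced0 = true;
  }
}
```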
@@ -2060,7 +2120,24 @@ static std::string GenerateSingleOpBase(
           "{ \"%s\", egr::EagerUtils::TrySyncToVars(hooked_grads[%d]) },";
       ins_contents_str += paddle::string::Sprintf(
           GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position);
+      if (!backward_inplace_map.empty() &&
+          backward_inplace_map.count(grad_input_name)) {
+        process_backward_inplace = true;
+        const char* GRAD_INS_HOOKED_GRAD_TEMPLATE =
+            "auto& %s = hooked_grads[%d][0];";
+        std::string hooked_grads_tensor_str = paddle::string::Sprintf(
+            GRAD_INS_HOOKED_GRAD_TEMPLATE, bwd_inplace_input_name,
+            fwd_output_position);
+        const char* GRAD_INS_GRAD_TENSOR_TEMPLATE = "grads[%d][0]";
+        std::string grads_tensor_str = paddle::string::Sprintf(
+            GRAD_INS_GRAD_TENSOR_TEMPLATE, fwd_output_position);
+        generated_grad_function_body += paddle::string::Sprintf(
+            CHECK_BACKWARD_INPLACE_TEMPLATE, can_be_inplaced_name,
+            hooked_grads_tensor_str, bwd_inplace_input_name,
+            bwd_inplace_input_name, grad_input_name, bwd_inplace_input_name,
+            bwd_inplace_input_name, bwd_inplace_input_name,
+            bwd_inplace_input_name, grads_tensor_str, can_be_inplaced_name);
+      }
     } else {
       PADDLE_THROW(platform::errors::Fatal(
           "Detected mismatched slot names."
@@ -2245,6 +2322,27 @@ static std::string GenerateSingleOpBase(
   VLOG(6) << "Generated Outs Map";

+  // [Generation] Process Backward Inplace
+  if (process_backward_inplace) {
+    const char* HANDLE_BACKWARD_INPLACE_BETWEEN_INPUT_AND_OUTPUT =
+        "  if (%s && %s.count(\"%s\") && %s.count(\"%s\")) {\n"
+        "    egr::EagerUtils::HandleViewBetweenInputAndOutput(%s[\"%s\"][0], "
+        "%s[\"%s\"][0]);\n"
+        "  };\n";
+    std::string backward_inplace_map_str = "";
+    for (auto iter : backward_inplace_map) {
+      std::string backward_inplace_input_name = iter.first;
+      std::string backward_inplace_output_name = iter.second;
+      backward_inplace_map_str += paddle::string::Sprintf(
+          HANDLE_BACKWARD_INPLACE_BETWEEN_INPUT_AND_OUTPUT,
+          can_be_inplaced_name, ins_name, backward_inplace_input_name,
+          outs_name, backward_inplace_output_name, ins_name,
+          backward_inplace_input_name, outs_name, backward_inplace_output_name);
+    }
+    generated_grad_function_body += backward_inplace_map_str;
+    VLOG(6) << "Process Backward Inplace";
+  }
+
   // [Generation] Get Attrs Map
   const char* ATTRS_TEMPLATE = "  auto& %s = this->attr_map_;\n";
   std::string grad_attrs_str =
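When the check above passed, the generated body then wires the grad input and grad output to the same buffer. For a backward inplace pair `Out@GRAD` → `X@GRAD` (names illustrative), and assuming the generated ins/outs map variables are named `ins0`/`outs0`, the emitted guard looks roughly like:

```cpp
// Generated guard (sketch): alias the grad output var to the grad input's
// buffer instead of allocating a fresh output tensor.
if (can_be_inplaced0 && ins0.count("Out@GRAD") && outs0.count("X@GRAD")) {
  egr::EagerUtils::HandleViewBetweenInputAndOutput(ins0["Out@GRAD"][0],
                                                   outs0["X@GRAD"][0]);
};
```

Reusing `HandleViewBetweenInputAndOutput` here mirrors the forward view strategy: both cases boil down to making two variables share one underlying allocation.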
@@ -2428,13 +2526,15 @@ static std::string GenerateGradNodeCCContents(
     const auto& grad_ins = op_base_info.GetGradIns();
     const auto& grad_outs = op_base_info.GetGradOuts();
     const auto& grad_attrs = op_base_info.GetGradAttrs();
+    const auto& backward_inplace_map = op_base_info.GetBackwardInplaceMap();
     const std::string& op_base_type = op_base_info.GetOpBaseType();
     generated_grad_function_body += GenerateSingleOpBase(
         fwd_op_type, op_base_type, fwd_inputs_name_pos_map,
         fwd_outputs_name_pos_map, in_vars, grad_ins_fwd_slotname_map,
         grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, grad_outs,
-        grad_attrs, is_op_base_per_duplicable_input, &outs_size);
+        grad_attrs, backward_inplace_map, is_op_base_per_duplicable_input,
+        &outs_size);
   }

   if (is_op_base_per_duplicable_input) {
@@ -2847,19 +2947,20 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
     auto& infer_inplace =
         paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;
-    std::map<std::string, std::string> inplace_map;
+    std::map<std::string, std::string> forward_inplace_map;
     // Inplace Function Generator.
     // `sum` op has duplicate input. Don't consider adding inplace strategy
     // for `sum` in temporary.
     if (infer_inplace && !special_inplace_op_set.count(op_type)) {
       auto in_to_outs = infer_inplace(true);
       for (auto& inplace_pair : in_to_outs) {
-        inplace_map[inplace_pair.second] = inplace_pair.first;
+        forward_inplace_map[inplace_pair.second] = inplace_pair.first;
       }

       VLOG(6) << "-------- GenerateInplaceForwardFunctionContents -------";
       std::pair<std::string, std::string> inplace_body_and_declaration =
-          GenerateForwardFunctionContents(fwd_info, bwd_info, inplace_map);
+          GenerateForwardFunctionContents(fwd_info, bwd_info,
+                                          forward_inplace_map);
       fwd_function_str += inplace_body_and_declaration.first + "\n";
...