Unverified commit ed2886de authored by pangyoki and committed by GitHub

support backward inplace in eager fluid dygraph mode (#43054)

* support backward inplace in eager fluid mode

* fix

* fix

* optimize format

* little change
Parent 3d56d419
@@ -231,6 +231,15 @@ class GradNodeGenerationInfo {
return &no_need_buffer_ins_;
}
const std::unordered_map<std::string, std::string>& GetBackwardInplaceMap()
const {
return backward_inplace_map_;
}
std::unordered_map<std::string, std::string>*
GetMutableBackwardInplaceMap() {
return &backward_inplace_map_;
}
private:
std::string op_base_type_;
std::map<std::string, std::string> grad_outs_slotname_map_;
@@ -244,6 +253,7 @@ class GradNodeGenerationInfo {
grad_outs_;
paddle::framework::AttributeMap grad_attrs_;
std::unordered_set<std::string> no_need_buffer_ins_;
std::unordered_map<std::string, std::string> backward_inplace_map_;
};
public:
@@ -979,6 +989,12 @@ static bool CollectGradInformationFromOpInfo(
*(*op_base_infos)[index].GetMutableNoNeedBufferInputs() =
inferer(g_ins, g_outs, *op_base_grad_attrs);
}
auto& infer_backward_inplace = op_base.Info().infer_inplace_;
if (infer_backward_inplace) {
*(*op_base_infos)[index].GetMutableBackwardInplaceMap() =
infer_backward_inplace(true);
}
}
/* ------ Slot Name Matching ---- */
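For reference, a minimal sketch of what this collection step stores, assuming a grad op base whose registered InplaceOpInference pairs the incoming gradient with the produced gradient (the names below are hypothetical examples, not taken from any specific op):

    // Sketch only: the real pairs come from the grad op's InplaceOpInference.
    auto backward_inplace_map = infer_backward_inplace(true);
    // backward_inplace_map == { {"Out@GRAD", "X@GRAD"} }
    // i.e. the buffer of grad input Out@GRAD may be reused for grad output X@GRAD.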
@@ -1005,7 +1021,7 @@ static std::string GenerateGradNodeCreationContent(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info,
const std::string& trace_op_body_str,
std::map<std::string, std::string> inplace_map = {}) {
std::map<std::string, std::string> forward_inplace_map = {}) {
VLOG(6) << "Generating GradNode Creation codes";
const std::string& op_type = fwd_info.GetOpType();
@@ -1045,8 +1061,10 @@ static std::string GenerateGradNodeCreationContent(
} else {
// In inplace op, the case where output is duplicable is not considered.
// Replace output directly with input in inplace op.
if (!inplace_map.empty() && inplace_map.count(output_name)) {
auto inplace_input_name = LegalizeVarName(inplace_map[output_name]);
if (!forward_inplace_map.empty() &&
forward_inplace_map.count(output_name)) {
auto inplace_input_name =
LegalizeVarName(forward_inplace_map[output_name]);
const std::string& inplace_input_autograd_name =
"p_autograd_" + inplace_input_name;
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
@@ -1103,12 +1121,12 @@ static std::string GenerateGradNodeCreationContent(
// check inplace input to avoid inplace operations on leaf nodes with
// stop_gradient=False.
std::string check_inplace_str = "";
if (!inplace_map.empty()) {
if (!forward_inplace_map.empty()) {
const char* CHECKING_INPLACE_TEMPLATE =
" // Check Inplace\n"
" egr::EagerUtils::CheckInplace(%s, p_autograd_%s, "
"require_any_grad);\n";
for (auto& inplace_pair : inplace_map) {
for (auto& inplace_pair : forward_inplace_map) {
std::string inplace_name = LegalizeVarName(inplace_pair.second);
check_inplace_str += paddle::string::Sprintf(CHECKING_INPLACE_TEMPLATE,
inplace_name, inplace_name);
@@ -1161,8 +1179,9 @@ static std::string GenerateGradNodeCreationContent(
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" grad_node->SetTensorWrapper%s(%s);\n";
// Replace output directly with input in inplace op.
if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) {
auto inplace_input_name = inplace_map[tensor_wrapper_name];
if (!forward_inplace_map.empty() &&
forward_inplace_map.count(tensor_wrapper_name)) {
auto inplace_input_name = forward_inplace_map[tensor_wrapper_name];
grad_node_creation_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
LegalizeVarName(inplace_input_name));
@@ -1213,8 +1232,9 @@ static std::string GenerateGradNodeCreationContent(
for (const proto::OpProto::Var& output : out_vars) {
const std::string& output_name = output.name();
// Replace output directly with input in inplace op.
if (!inplace_map.empty() && inplace_map.count(output_name)) {
auto inplace_input_name = inplace_map[output_name];
if (!forward_inplace_map.empty() &&
forward_inplace_map.count(output_name)) {
auto inplace_input_name = forward_inplace_map[output_name];
const std::string& inplace_input_autograd_name =
"p_autograd_" + LegalizeVarName(inplace_input_name);
size_t output_position = fwd_outputs_name_pos_map.at(output_name);
@@ -1345,7 +1365,7 @@ static std::string GenerateGradNodeCreationContent(
static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info,
std::map<std::string, std::string> inplace_map = {}) {
std::map<std::string, std::string> forward_inplace_map = {}) {
/* --- Process Forward Info ---*/
const std::string& op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
@@ -1434,8 +1454,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// inplace tensor can't be const
const char* FWD_INS_ARG_TEMPLATE;
bool flag_find_input_name = false;
if (!inplace_map.empty()) {
for (auto& inplace_pair : inplace_map) {
if (!forward_inplace_map.empty()) {
for (auto& inplace_pair : forward_inplace_map) {
if (inplace_pair.second == input_name) {
flag_find_input_name = true;
FWD_INS_ARG_TEMPLATE = "paddle::experimental::Tensor& %s";
@@ -1605,15 +1625,16 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
}
core_ops_args_info[op_type].push_back(output_name);
} else if (!inplace_map.empty() && inplace_map.count(output_name)) {
} else if (!forward_inplace_map.empty() &&
forward_inplace_map.count(output_name)) {
// In inplace op, replace the output with the input directly.
PADDLE_ENFORCE_NE(
inplace_map[output_name], "",
forward_inplace_map[output_name], "",
paddle::platform::errors::InvalidArgument(
"Inplace op %s has no input corresponding to output %s.", op_type,
output_name));
const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", ins[\"%s\"] },";
auto inplace_input_name = inplace_map[output_name];
auto inplace_input_name = forward_inplace_map[output_name];
outs_contents_str += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, output_name, inplace_input_name);
@@ -1651,7 +1672,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (inplace_mapping_str.size() > 0)
inplace_mapping_str.pop_back(); // Remove trailing ","
if ((op_type != "cast") && (inplace_map.empty())) {
if ((op_type != "cast") && (forward_inplace_map.empty())) {
VLOG(6) << "Generating Dygraph Forward AMP";
const char* AMP_LOGIC_CONTEXT =
" if (egr::Controller::Instance().GetAMPLevel() != "
@@ -1743,7 +1764,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
VLOG(6) << "Generated Outs Map";
// [Generation] Apply View Strategy (Tensor)
if (inplace_map.empty() && view_op_map.count(op_type)) {
if (forward_inplace_map.empty() && view_op_map.count(op_type)) {
const char* HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT =
" if (ins.count(\"%s\") && outs.count(\"%s\")) {\n"
" egr::EagerUtils::HandleViewBetweenInputAndOutput(ins[\"%s\"][0], "
@@ -1852,10 +1873,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
output_varname, output_var_args_name);
}
} else {
if (!inplace_map.empty() && inplace_map.count(output_name)) {
if (!forward_inplace_map.empty() &&
forward_inplace_map.count(output_name)) {
// Modify meta info of inplace tensor.
// Bump inplace version of inplace tensor.
auto inplace_input_name = inplace_map[output_name];
auto inplace_input_name = forward_inplace_map[output_name];
const char* FWD_OUT_TENSOR_TEMPLATE =
" egr::EagerUtils::GetOutput(outs[\"%s\"][0], &%s);\n"
" %s.bump_inplace_version();\n"
@@ -1878,10 +1900,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
return_types[return_position] = "paddle::experimental::Tensor";
}
if (!inplace_map.empty() && inplace_map.count(output_name)) {
if (!forward_inplace_map.empty() &&
forward_inplace_map.count(output_name)) {
// Replace output directly with input in inplace op.
return_contents[return_position] =
LegalizeVarName(inplace_map[output_name]);
LegalizeVarName(forward_inplace_map[output_name]);
} else {
return_contents[return_position] = output_varname;
}
@@ -1903,7 +1926,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// If GradNode needs to be generated, pass `trace_op_body_str`
// into `GenerateGradNodeCreationContent`.
std::string grad_node_creation_body_str = GenerateGradNodeCreationContent(
fwd_info, bwd_info, trace_op_body_str, inplace_map);
fwd_info, bwd_info, trace_op_body_str, forward_inplace_map);
generated_function_body += grad_node_creation_body_str;
generated_function_body += "\n";
@@ -1960,7 +1983,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// [Generation] Get Full Function
std::string function_name;
if (inplace_map.empty()) {
if (forward_inplace_map.empty()) {
function_name = op_type + "_dygraph_function";
} else {
// change function_name for inplace op.
@@ -2013,6 +2036,7 @@ static std::string GenerateSingleOpBase(
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_outs,
const paddle::framework::AttributeMap& grad_attrs,
const std::unordered_map<std::string, std::string>& backward_inplace_map,
bool is_op_base_per_duplicable_input, size_t* outs_size) {
std::string generated_grad_function_body = "";
@@ -2029,6 +2053,23 @@ static std::string GenerateSingleOpBase(
for (const auto& in : in_vars) {
if (in.duplicable()) duplicable_input_name_set.insert(in.name());
}
const char* CHECK_BACKWARD_INPLACE_TEMPLATE =
" // Check backward inplace info\n"
" bool %s = false;\n"
" %s\n"
" if (%s.initialized()) {\n"
" VLOG(10) << %s.name() << \"(%s) use_count: \" << "
"%s.impl().use_count();\n"
" if (%s.impl().use_count() == 1 || (%s.impl().use_count() == 2 && "
"%s.impl().get() == %s.impl().get())) {\n"
" %s = true;\n"
" }\n"
" }\n";
const std::string& can_be_inplaced_name =
"can_be_inplaced" + std::to_string(*outs_size);
const std::string& bwd_inplace_input_name =
"backward_inplace_tensor" + std::to_string(*outs_size);
bool process_backward_inplace = false;
std::string ins_contents_str = "";
for (auto iter : grad_ins) {
const std::string& grad_input_name = iter.first;
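For reference, a sketch of what CHECK_BACKWARD_INPLACE_TEMPLATE expands to once the placeholders are filled by the tensor-wrapper branch shown in the next hunk. The generated variable names come from can_be_inplaced_name and bwd_inplace_input_name above; the forward input name X, the tensor wrapper member X_, and the suffix 0 (i.e. *outs_size == 0) are hypothetical:

    // Check backward inplace info
    bool can_be_inplaced0 = false;
    auto backward_inplace_tensor0 = egr::EagerUtils::RecoverTensorWrapper(&this->X_);
    if (backward_inplace_tensor0.initialized()) {
      VLOG(10) << backward_inplace_tensor0.name() << "(X) use_count: "
               << backward_inplace_tensor0.impl().use_count();
      if (backward_inplace_tensor0.impl().use_count() == 1 ||
          (backward_inplace_tensor0.impl().use_count() == 2 &&
           backward_inplace_tensor0.impl().get() ==
               (&this->X_)->get_intermidiate_tensor().impl().get())) {
        can_be_inplaced0 = true;
      }
    }

In words: the grad input's buffer may be reused only when the recovered tensor holds the sole reference to its underlying storage, or when the only other reference is the tensor wrapper's own intermediate copy.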
@@ -2051,7 +2092,26 @@ static std::string GenerateSingleOpBase(
ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE,
grad_input_name, struct_fwd_input_name);
if (!backward_inplace_map.empty() &&
backward_inplace_map.count(grad_input_name)) {
process_backward_inplace = true;
const char* GRAD_INS_FWD_TENSOR_WRAPPER_TEMPLATE =
"auto %s = egr::EagerUtils::RecoverTensorWrapper(&this->%s);";
std::string tensor_wrapper_str = paddle::string::Sprintf(
GRAD_INS_FWD_TENSOR_WRAPPER_TEMPLATE, bwd_inplace_input_name,
struct_fwd_input_name);
const char* GRAD_INS_FWD_TENSOR_TEMPLATE =
"(&this->%s)->get_intermidiate_tensor()";
std::string tensor_wrapper_intermidiate_tensor_str =
paddle::string::Sprintf(GRAD_INS_FWD_TENSOR_TEMPLATE,
struct_fwd_input_name);
generated_grad_function_body += paddle::string::Sprintf(
CHECK_BACKWARD_INPLACE_TEMPLATE, can_be_inplaced_name,
tensor_wrapper_str, bwd_inplace_input_name, bwd_inplace_input_name,
grad_input_name, bwd_inplace_input_name, bwd_inplace_input_name,
bwd_inplace_input_name, bwd_inplace_input_name,
tensor_wrapper_intermidiate_tensor_str, can_be_inplaced_name);
}
} else if (grad_ins_grad_slotname_map.count(grad_input_name)) {
// Fwd Tensor's Grad
size_t fwd_output_position = fwd_outputs_name_pos_map.at(
@@ -2060,7 +2120,24 @@ static std::string GenerateSingleOpBase(
"{ \"%s\", egr::EagerUtils::TrySyncToVars(hooked_grads[%d]) },";
ins_contents_str += paddle::string::Sprintf(
GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position);
if (!backward_inplace_map.empty() &&
backward_inplace_map.count(grad_input_name)) {
process_backward_inplace = true;
const char* GRAD_INS_HOOKED_GRAD_TEMPLATE =
"auto& %s = hooked_grads[%d][0];";
std::string hooked_grads_tensor_str = paddle::string::Sprintf(
GRAD_INS_HOOKED_GRAD_TEMPLATE, bwd_inplace_input_name,
fwd_output_position);
const char* GRAD_INS_GRAD_TENSOR_TEMPLATE = "grads[%d][0]";
std::string grads_tensor_str = paddle::string::Sprintf(
GRAD_INS_GRAD_TENSOR_TEMPLATE, fwd_output_position);
generated_grad_function_body += paddle::string::Sprintf(
CHECK_BACKWARD_INPLACE_TEMPLATE, can_be_inplaced_name,
hooked_grads_tensor_str, bwd_inplace_input_name,
bwd_inplace_input_name, grad_input_name, bwd_inplace_input_name,
bwd_inplace_input_name, bwd_inplace_input_name,
bwd_inplace_input_name, grads_tensor_str, can_be_inplaced_name);
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
@@ -2245,6 +2322,27 @@ static std::string GenerateSingleOpBase(
VLOG(6) << "Generated Outs Map";
// [Generation] Process Backward Inplace
if (process_backward_inplace) {
const char* HANDLE_BACKWARD_INPLACE_BETWEEN_INPUT_AND_OUTPUT =
" if (%s && %s.count(\"%s\") && %s.count(\"%s\")) {\n"
" egr::EagerUtils::HandleViewBetweenInputAndOutput(%s[\"%s\"][0], "
"%s[\"%s\"][0]);\n"
" };\n";
std::string backward_inplace_map_str = "";
for (auto iter : backward_inplace_map) {
std::string backward_inplace_input_name = iter.first;
std::string backward_inplace_output_name = iter.second;
backward_inplace_map_str += paddle::string::Sprintf(
HANDLE_BACKWARD_INPLACE_BETWEEN_INPUT_AND_OUTPUT,
can_be_inplaced_name, ins_name, backward_inplace_input_name,
outs_name, backward_inplace_output_name, ins_name,
backward_inplace_input_name, outs_name, backward_inplace_output_name);
}
generated_grad_function_body += backward_inplace_map_str;
VLOG(6) << "Process Backward Inplace";
}
// [Generation] Get Attrs Map
const char* ATTRS_TEMPLATE = " auto& %s = this->attr_map_;\n";
std::string grad_attrs_str =
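A sketch of the guard this block emits, assuming the generated in/out maps are named ins0 and outs0, *outs_size is 0, and the backward inplace pair is Out@GRAD -> X@GRAD (all hypothetical names):

    if (can_be_inplaced0 && ins0.count("Out@GRAD") && outs0.count("X@GRAD")) {
      egr::EagerUtils::HandleViewBetweenInputAndOutput(ins0["Out@GRAD"][0], outs0["X@GRAD"][0]);
    };

When the guard passes, the grad output variable is set up to share the grad input's buffer, letting the grad kernel write its result in place.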
@@ -2428,13 +2526,15 @@ static std::string GenerateGradNodeCCContents(
const auto& grad_ins = op_base_info.GetGradIns();
const auto& grad_outs = op_base_info.GetGradOuts();
const auto& grad_attrs = op_base_info.GetGradAttrs();
const auto& backward_inplace_map = op_base_info.GetBackwardInplaceMap();
const std::string& op_base_type = op_base_info.GetOpBaseType();
generated_grad_function_body += GenerateSingleOpBase(
fwd_op_type, op_base_type, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, in_vars, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, grad_outs,
grad_attrs, is_op_base_per_duplicable_input, &outs_size);
grad_attrs, backward_inplace_map, is_op_base_per_duplicable_input,
&outs_size);
}
if (is_op_base_per_duplicable_input) {
@@ -2847,19 +2947,20 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
auto& infer_inplace =
paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;
std::map<std::string, std::string> inplace_map;
std::map<std::string, std::string> forward_inplace_map;
// Inplace Function Generator.
  // `sum` op has a duplicable input. Don't add an inplace strategy
  // for `sum` for now.
if (infer_inplace && !special_inplace_op_set.count(op_type)) {
auto in_to_outs = infer_inplace(true);
for (auto& inplace_pair : in_to_outs) {
inplace_map[inplace_pair.second] = inplace_pair.first;
forward_inplace_map[inplace_pair.second] = inplace_pair.first;
}
VLOG(6) << "-------- GenerateInplaceForwardFunctionContents -------";
std::pair<std::string, std::string> inplace_body_and_declaration =
GenerateForwardFunctionContents(fwd_info, bwd_info, inplace_map);
GenerateForwardFunctionContents(fwd_info, bwd_info,
forward_inplace_map);
fwd_function_str += inplace_body_and_declaration.first + "\n";
......
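For reference, a minimal sketch of the inversion performed above, assuming an op whose InplaceOpInference reports that input X may be reused as output Out (names hypothetical):

    auto in_to_outs = infer_inplace(true);  // e.g. { {"X", "Out"} }
    std::map<std::string, std::string> forward_inplace_map;
    for (auto& inplace_pair : in_to_outs) {
      forward_inplace_map[inplace_pair.second] = inplace_pair.first;
    }
    // forward_inplace_map == { {"Out", "X"} }: keyed by output name, so the
    // generator can substitute the input wherever that output appears.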