diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py index 8b7ea547b2632a052d58d0c7679130dc925d5fcf..e16bcb187f85a7fc59fc06a0d8564841c384c7e2 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py @@ -21,7 +21,8 @@ import os ######################## ### Global Variables ### ######################## -ops_to_fill_zero_for_empty_grads = set(["split", "rnn"]) +ops_to_fill_zero_for_empty_grads = set( + ["split_grad", "rnn_grad", "matmul_double_grad"]) # For API dispatch used at python-level # { op_name : [arg_name, ...] } @@ -176,6 +177,11 @@ def TransformGradVarNameForDoubleGradGeneration(string): return string +def GetIndent(num): + tab = " " + return "".join([tab for i in range(num)]) + + ###################### ### Yaml Parsers ### ###################### diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 0f78763d6c959b1bfda61f80663334c680bf062c..fb86c5da6856c51c446455a4112f84e31c4d76a6 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -32,7 +32,7 @@ from codegen_utils import ParseYamlForward, ParseYamlBackward from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase from codegen_utils import ops_to_fill_zero_for_empty_grads from codegen_utils import TransformGradVarNameForDoubleGradGeneration -from codegen_utils import AssertMessage +from codegen_utils import AssertMessage, GetIndent ########### @@ -112,80 +112,81 @@ ATTRIBUTE_MEMBER_TEMPLATE = \ NODE_DECLARATION_TEMPLATE = \ """ - class {} : public egr::GradNodeBase {{ - public: - {}() : egr::GradNodeBase() {{}} - {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : - egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}} - ~{}() override = default; - - virtual std::vector> operator()( - std::vector>& grads, bool create_graph = false) override; - std::string name() override {{ return \" {} \"; }} - - void ClearTensorWrappers() override {{ - {} - is_tensor_wrappers_cleared = true; - }} - - // SetTensorWrapperX, SetTensorWrapperY, ... +class {} : public egr::GradNodeBase {{ + public: + {}() : egr::GradNodeBase() {{}} + {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : + egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}} + ~{}() override = default; + + virtual std::vector> operator()( + std::vector>& grads, bool create_graph = false) override; + std::string name() override {{ return \" {} \"; }} + + void ClearTensorWrappers() override {{ {} - // SetAttributes - {} - - bool IsTensorWrappersCleared() override {{ - return is_tensor_wrappers_cleared; - }} - private: - // TensorWrappers - {} - - bool is_tensor_wrappers_cleared = false; - - // Attributes - {} - }}; + is_tensor_wrappers_cleared = true; + }} + + // SetTensorWrapperX, SetTensorWrapperY, ... + {} + // SetAttributes + {} + + bool IsTensorWrappersCleared() override {{ + return is_tensor_wrappers_cleared; + }} + private: + // TensorWrappers + {} + + bool is_tensor_wrappers_cleared = false; + + // Attributes + {} +}}; """ GRAD_FUNCTION_TEMPLATE = \ """ - std::vector> {}::operator()(std::vector>& grads, bool create_graph) {{ - // Fill Zero For GradIn Tensors - {} +std::vector> {}::operator()(std::vector>& grads, bool create_graph) {{ + // Fill Zero For GradIn Tensors +{} - // Apply Gradient Hooks - auto hooked_grads = ApplyGradientHooks(grads); - - // Collect GradIn Tensors, Attrs and Recovered TensorWrappers - {} + // Apply Gradient Hooks + auto hooked_grads = ApplyGradientHooks(grads); + + // Collect GradIn Tensors, Attrs and Recovered TensorWrappers +{} - // Call grad_api function - VLOG(3) << \"Final State Running: \" << \"{}\"; - {} + // Call grad_api function + VLOG(3) << \"Final State Running: \" << \"{}\"; +{} - // Get Output - {} + // Get Output +{} - // Get GradIn autograd_meta - {} + // Get GradIn autograd_meta +{} - // Get GradOut autograd_meta - {} - - // Compute Require Grad - {} - - // Create Grad Node - {} + // Get GradOut autograd_meta +{} + + // Compute Require Grad +{} + + // Create Grad Node +{} - // Return - {} - }} + // Return +{} + +}} """ FORWARD_FUNCTION_TEMPLATE = \ """ - {} {}({}) {{ +{} {}({}) {{ // Dygraph Record Event {} // AMP Logic @@ -208,33 +209,33 @@ FORWARD_FUNCTION_TEMPLATE = \ // Node Creation {} - // Returns - return {}; - }} + // Returns + return {}; +}} """ FORWARD_BODY_TEMPLATE = \ """ - if(require_any_grad) {{ + if(require_any_grad) {{ {} - egr::EagerUtils::PassStopGradient({}); - - // Node Construction + egr::EagerUtils::PassStopGradient({}); + + // Node Construction {} - // SetAttributes + // SetAttributes {} - // SetTensorWrappers + // SetTensorWrappers {} - // SetGradOutMeta & SetEdges + // SetGradOutMeta & SetEdges {} {} - // SetOutRank & SetHistory & SetGradInMeta & RetainGrad + // SetOutRank & SetHistory & SetGradInMeta & RetainGrad {} {} {} {} - }} + }} """ NAMESPACE_WRAPPER_TEMPLATE = \ @@ -318,9 +319,9 @@ std::unordered_map> core_ops_final_state_r CORE_OPS_DECLARATION_TEMPLATE = \ """ - extern std::unordered_map> core_ops_final_state_args_info; - extern std::unordered_map> core_ops_final_state_args_type_info; - extern std::unordered_map> core_ops_final_state_returns_info; +extern std::unordered_map> core_ops_final_state_args_info; +extern std::unordered_map> core_ops_final_state_args_type_info; +extern std::unordered_map> core_ops_final_state_returns_info; """ @@ -352,6 +353,12 @@ AMP_LOGIC_TEMPLATE = \ }} """ +CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE = \ +""" + paddle::optional {}_optional = paddle::none; + if({}.initialized()) {}_optional = paddle::make_optional({}); +""" + ####################### ## Generator Helpers ## @@ -678,12 +685,15 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): num_backward_inputs = len(forward_outputs_position_map.keys()) num_backward_outputs = len(forward_inputs_position_map.keys()) grad_node_name = GetGradNodeName(forward_api_name) + + # Helper + indent = GetIndent(2) # NOTE(Aurelius74): DO NOT use make_shared here. Because some Node contains experimental::Scalar # which contains "complex128" as data. "complex128" is memory-aligned manually. But make_shared # request MEMALIGN for allocation (Maybe). # See https://stackoverflow.com/questions/31228656/how-can-shared-ptr-disrupt-alignment # and https://github.com/MRtrix3/mrtrix3/issues/957 - node_construction_str = f" auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));" + node_construction_str = f"{indent}auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));" # SetAttributes set_attributes_list = [] @@ -693,9 +703,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): for name, _, default_val_attr, _ in backward_attrs_list: if name in forward_attrs_name_set: - set_attributes = f" grad_node->SetAttribute{name}({name});" + set_attributes = f"{indent}grad_node->SetAttribute{name}({name});" else: - set_attributes = f" grad_node->SetAttribute{name}({default_val_attr});" + set_attributes = f"{indent}grad_node->SetAttribute{name}({default_val_attr});" set_attributes_list.append(set_attributes) set_attributes_str = "\n".join(set_attributes_list) @@ -708,9 +718,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): if is_fwd_input: if is_optional: - set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);" + set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);" else: - set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" + set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, true);" else: if num_fwd_outputs > 1: # Aligned with forward output position @@ -719,9 +729,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): fwd_output_pos = forward_outputs_position_map[name][1] if is_optional: - set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);" + set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);" else: - set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);" + set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, false);" set_tensor_wrappers_list.append(set_tensor_wrappers) set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) @@ -732,11 +742,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): input_autograd_meta_name = GetAutoGradMetaName(name) is_optional = (name in self.optional_inputs) if is_optional: - set_grad_out_meta = f" if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});" - set_edges = f" if({name}.get_ptr() != nullptr) grad_node->AddEdges({input_autograd_meta_name}, {pos});" + set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});" + set_edges = f"{indent}if({name}.get_ptr() != nullptr) grad_node->AddEdges({input_autograd_meta_name}, {pos});" else: - set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});" - set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});" + set_grad_out_meta = f"{indent}grad_node->SetGradOutMeta({name}, {pos});" + set_edges = f"{indent}grad_node->AddEdges({input_autograd_meta_name}, {pos});" set_grad_out_meta_list.append(set_grad_out_meta) set_edges_list.append(set_edges) @@ -751,11 +761,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): num_outputs = len(forward_outputs_position_map.keys()) for name, (_, pos) in forward_outputs_position_map.items(): output_autograd_meta_name = GetAutoGradMetaName(name) - set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});" - set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);" + set_out_rank = f"{indent}egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});" + set_history = f"{indent}egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);" - set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad({name});" - set_grad_in_meta = f" grad_node->SetGradInMeta({name}, {pos});" + set_retain_grad = f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});" + set_grad_in_meta = f"{indent}grad_node->SetGradInMeta({name}, {pos});" set_out_rank_list.append(set_out_rank) set_history_list.append(set_history) set_grad_in_meta_list.append(set_grad_in_meta) @@ -767,7 +777,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): set_retain_grad_str = "\n".join(set_retain_grad_list) node_event_name = forward_api_name + " node_creation" - node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n" + node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n" self.node_creation_str = FORWARD_BODY_TEMPLATE.format( node_creation_event_str, pass_stop_gradient_args_str, @@ -845,6 +855,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): optional_inputs = self.optional_inputs intermediate_outputs = self.intermediate_outputs inplace_map = self.inplace_map if is_inplaced else {} + indent = GetIndent(1) # Get Function Args num_inputs = len(forward_attrs_list) + len( @@ -918,7 +929,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): else: function_name = GetIntermediateAPIFunctionName(function_name) - forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" + forward_call_str = f"{indent}auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" num_outputs = len(forward_outputs_position_map.keys()) - len( intermediate_outputs) @@ -926,9 +937,9 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): get_outputs_str = "" for name, (rtype, pos) in forward_outputs_position_map.items(): if num_outputs == 1 and len(intermediate_outputs) == 0: - get_outputs_str += f"auto& {name} = api_result;\n" + get_outputs_str += f"{indent}auto& {name} = api_result;\n" else: - get_outputs_str += f"auto& {name} = std::get<{pos}>(api_result);\n" + get_outputs_str += f"{indent}auto& {name} = std::get<{pos}>(api_result);\n" # Get return type list & outputs returns_type_list = ["" for i in range(num_outputs)] @@ -961,12 +972,12 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): for name, (ttype, pos) in forward_inputs_position_map.items(): input_autograd_meta_name = GetAutoGradMetaName(name) if IsPlainTensorType(ttype): - input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" + input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" else: assert IsVectorTensorType(ttype) input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) - input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" - input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" + input_autograd_meta = f"{indent}std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" + input_autograd_meta += f"{indent}std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" inputs_autograd_meta_list.append(input_autograd_meta) compute_require_grad_args_list.append(input_autograd_meta_name) @@ -981,19 +992,19 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) if num_fwd_outputs == 1: if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});" + output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});" else: assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n" - output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" + output_autograd_meta = f"{indent}std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n" + output_autograd_meta += f"{indent}std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" else: # Tuple api_result if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});" + output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});" else: assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n" - output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" + output_autograd_meta = f"{indent}std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n" + output_autograd_meta += f"{indent}std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" outputs_autograd_meta_list.append(output_autograd_meta) outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) @@ -1012,7 +1023,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): self.GenerateNodeCreationCodes() node_creation_str = self.node_creation_str - dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);" + dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);" forward_function_name = GetDygraphForwardFunctionName(forward_api_name) # Forward amp logic @@ -1119,6 +1130,34 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): self.node_definition_str = "" self.next_grad_api_contents = next_grad_api_contents + def ResetOptionalInputs(self): + namespace = self.namespace + grad_api_contents = self.grad_api_contents + + base_generator = FunctionGeneratorBase(grad_api_contents, namespace) + base_generator.ParseDispensable() + + self.optional_inputs = base_generator.optional_inputs + + def GenerateHigherOrderNodeCreationCode(self): + namespace = self.namespace + grad_api_contents = self.grad_api_contents + next_grad_api_contents = self.next_grad_api_contents + + grad_node_creation_str = "" + if next_grad_api_contents: + forward_api_contents = grad_api_contents + forward_api_contents['api'] = forward_api_contents['backward_api'] + backward_api_contents = next_grad_api_contents + + next_node_generator = DygraphFunctionGeneratorBase( + forward_api_contents, backward_api_contents, namespace) + next_node_generator.run() + next_node_generator.GenerateNodeCreationCodes() + grad_node_creation_str = next_node_generator.node_creation_str + + return grad_node_creation_str + def GenerateNodeDeclaration(self): forward_op_name = self.forward_api_name backward_forward_inputs_map = self.backward_forward_inputs_map @@ -1187,6 +1226,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): backward_grad_inputs_map = self.backward_grad_inputs_map backward_grad_outputs_map = self.backward_grad_outputs_map backward_attrs_list = self.backward_attrs_list + indent = GetIndent(1) # Construct grad_api function args # Order: TensorWrappers, GradTensors, Attributes @@ -1197,8 +1237,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): # Fill Grad Ins with Zero fill_zero_str = "" - if forward_api_name in ops_to_fill_zero_for_empty_grads: - fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n" + if backward_api_name in ops_to_fill_zero_for_empty_grads: + fill_zero_str = f"{indent}egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n" # Grad Ins from TensorWrappers for name, (_, is_fwd_input, @@ -1209,9 +1249,9 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): is_optional = (name in self.optional_inputs) if is_optional: - tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" + tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" else: - tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" + tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" grad_api_args[grad_api_position] = transformed_tensor_name get_grad_in_args_list.append(tensor_wrapper_recover_str) @@ -1221,18 +1261,29 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration( name) + is_optional = (name in self.optional_inputs) if IsPlainTensorType(ttype): - get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];" + get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];" + + if is_optional: + get_tensor_str += "\n" + CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE.format( + transformed_tensor_name, transformed_tensor_name, + transformed_tensor_name, transformed_tensor_name) + grad_api_args[ + grad_api_position] = f"{transformed_tensor_name}_optional" + else: + grad_api_args[grad_api_position] = transformed_tensor_name else: assert IsVectorTensorType(ttype) - get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];" - grad_api_args[grad_api_position] = transformed_tensor_name + get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];" + grad_api_args[grad_api_position] = transformed_tensor_name + get_grad_in_args_list.append(get_tensor_str) # Grad Attrs for name, _, _, grad_api_position in backward_attrs_list: saved_attribute_name = GetSavedName(name) - get_attr_str = f"auto& {name} = this->{saved_attribute_name};" + get_attr_str = f"{indent}auto& {name} = this->{saved_attribute_name};" grad_api_args[grad_api_position] = name get_grad_in_args_list.append(get_attr_str) @@ -1242,7 +1293,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): # Grad Function Call String grad_api_namespace = f"paddle::experimental::{namespace}" - grad_function_call_str = f"auto grad_api_result = {grad_api_namespace}{backward_api_name}({grad_api_args_str});" + grad_function_call_str = f"{indent}auto grad_api_result = {grad_api_namespace}{backward_api_name}({grad_api_args_str});" # Get Grad Outputs get_outputs_str = "" @@ -1253,9 +1304,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): name) if num_outputs == 1: - get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result;" + get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result;" else: - get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}];" + if IsPlainTensorType(ttype): + get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}][0];" + else: + assert IsVectorTensorType(ttype) + get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}];" get_outputs_str += get_tensor_str + "\n" # Prepare for Node Creation if Necessary @@ -1274,13 +1329,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): input_autograd_meta_name = GetAutoGradMetaName( transformed_tensor_name) if IsPlainTensorType(ttype): - input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});" + input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});" else: assert IsVectorTensorType(ttype) input_autograd_meta_vec_name = GetAutoGradMetaVectorName( transformed_tensor_name) - input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n" - input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" + input_autograd_meta = f"{indent}std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n" + input_autograd_meta += f"{indent}std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" inputs_autograd_meta_list.append(input_autograd_meta) compute_require_grad_args_list.append(input_autograd_meta_name) @@ -1293,13 +1348,13 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): input_autograd_meta_name = GetAutoGradMetaName( transformed_tensor_name) if IsPlainTensorType(ttype): - input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});" + input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});" else: assert IsVectorTensorType(ttype) input_autograd_meta_vec_name = GetAutoGradMetaVectorName( transformed_tensor_name) - input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n" - input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" + input_autograd_meta = f"{indent}std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n" + input_autograd_meta += f"{indent}std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" inputs_autograd_meta_list.append(input_autograd_meta) compute_require_grad_args_list.append(input_autograd_meta_name) @@ -1320,30 +1375,30 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): transformed_tensor_name) if num_fwd_outputs == 1: if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});" + output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});" else: assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n" - output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" + output_autograd_meta = f"{indent}std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n" + output_autograd_meta += f"{indent}std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" else: # Tuple api_result if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});" + output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});" else: assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n" - output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" + output_autograd_meta = f"{indent}std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n" + output_autograd_meta += f"{indent}std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" outputs_autograd_meta_list.append(output_autograd_meta) outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) - compute_require_grad_str = "bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n" - compute_require_grad_str += f"bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});" + compute_require_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n" + compute_require_grad_str += f"{indent}bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});" # Construct grad_api returns num_bwd_outputs = len(backward_grad_outputs_map.keys()) slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys()) - returns_str = f"std::vector> returns({slot_num_bwd_outputs});\n" + returns_str = f"{indent}std::vector> returns({slot_num_bwd_outputs});\n" for name, (ttype, fwd_position, grad_api_position) in backward_grad_outputs_map.items(): transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration( @@ -1353,15 +1408,20 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): if num_bwd_outputs == 1: # Single tensor output, return as is if IsPlainTensorType(ttype): - returns_str += f"returns[0] = {{ {transformed_tensor_name} }};\n" + returns_str += f"{indent}returns[0] = {{ {transformed_tensor_name} }};\n" else: assert IsVectorTensorType(ttype) - returns_str += f"returns[0] = {transformed_tensor_name};\n" + returns_str += f"{indent}returns[0] = {transformed_tensor_name};\n" else: # Rearrange output order accordingly - returns_str += f"returns[{fwd_position}] = {transformed_tensor_name};\n" - returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" - returns_str += f"return returns;\n" + if IsPlainTensorType(ttype): + returns_str += f"{indent}returns[{fwd_position}] = {{ {transformed_tensor_name} }};\n" + else: + assert IsVectorTensorType(ttype) + returns_str += f"{indent}returns[{fwd_position}] = {transformed_tensor_name};\n" + + returns_str += f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" + returns_str += f"{indent}return returns;\n" grad_node_name = GetGradNodeName(forward_api_name) @@ -1376,24 +1436,15 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): def run(self): super().run() + self.ResetOptionalInputs() + ##################### ## Code Generation ## ##################### self.GenerateNodeDeclaration() - namespace = self.namespace - grad_node_creation_str = "" - next_grad_api_contents = self.next_grad_api_contents - if next_grad_api_contents: - forward_api_contents = self.grad_api_contents - forward_api_contents['api'] = forward_api_contents['backward_api'] - backward_api_contents = next_grad_api_contents - - next_node_generator = DygraphFunctionGeneratorBase( - forward_api_contents, backward_api_contents, namespace) - next_node_generator.run() - next_node_generator.GenerateNodeCreationCodes() - grad_node_creation_str = next_node_generator.node_creation_str + # Higher-order GradNode generation + grad_node_creation_str = self.GenerateHigherOrderNodeCreationCode() self.GenerateNodeDefinition(grad_node_creation_str) diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index ff4445f4261e321aaefee01469c4c2e94636ef9e..0d07f780dda9d1d88cd65c66b1ae0e31ee45cff6 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -14,6 +14,8 @@ #pragma once +#include + #include "paddle/fluid/eager/eager_tensor.h" #include "paddle/fluid/eager/hooks.h" #include "paddle/phi/api/all.h" diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 794e2659302473b0b01ac49d4338d959051e567f..9e59d79408b0daf8a6696f38ff8e44df0bcb2607 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -23,6 +23,7 @@ from ...tensor.math import multiply import warnings from ...fluid.layer_helper import LayerHelper from ...fluid.framework import convert_np_dtype_to_dtype_ +from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode from ...fluid.data_feeder import check_variable_and_dtype, check_dtype import paddle from paddle import _C_ops, in_dynamic_mode @@ -560,9 +561,10 @@ def relu(x, name=None): out = F.relu(x) # [0., 0., 1.] """ - if in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_relu(x) + if _in_legacy_dygraph(): return _C_ops.relu(x) - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu') helper = LayerHelper('relu', **locals()) out = helper.create_variable_for_type_inference(x.dtype) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 3830d7f92689bec9c49ed2c504a0fe279fe53760..c981a068b64babb668cf468c624f0184f47f9730 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -377,15 +377,15 @@ data_type : x - backward_api : matmul_double_grad - forward : matmul_grad (Tensor x, Tensor y, Tensor out_grad, bool transpose_x, bool transpose_y) -> Tensor(dx), Tensor(dy) - args : (Tensor x, Tensor y, Tensor out_grad, Tensor dx_grad, Tensor dy_grad, bool transpose_x, bool transpose_y) - output : Tensor(d2x), Tensor(d2y), Tensor(dout_grad) + forward : matmul_grad (Tensor x, Tensor y, Tensor grad_out, bool transpose_x=false, bool transpose_y=false) -> Tensor(grad_x), Tensor(grad_y) + args : (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad, bool transpose_x=false, bool transpose_y=false) + output : Tensor(x_grad), Tensor(y_grad), Tensor(grad_out_grad) infer_meta : func : GeneralTernaryGradInferMeta - param : [x, y, out_grad] + param : [x, y, grad_out] kernel : func : matmul_double_grad - optional : dx_grad, dy_grad + optional : grad_x_grad, grad_y_grad - backward_api : matmul_grad forward : matmul (Tensor x, Tensor y, bool transpose_x=false, bool transpose_y=false) -> Tensor(out) @@ -396,6 +396,7 @@ param : [x, y] kernel : func : matmul_grad + backward : matmul_double_grad - backward_api : matrix_power_grad forward : matrix_power (Tensor x, int n) -> Tensor(out)