diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index fe0b021a2e5b5a24a358b9b26b558b8ccf00dd83..aa77a4290f98a0c8c00429ddb9c5a4110ee34fae 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -1780,6 +1780,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name def GenerateHigherOrderNodeCreationCode(self): + has_higher_order_node = False namespace = self.namespace grad_api_contents = self.grad_api_contents forward_apis_dict = self.forward_apis_dict @@ -1807,14 +1808,33 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): next_grad_node_out_list = next_node_generator.grad_node_out_list self.RecordGrad2NextGradNameMapping(next_node_generator) + + is_invoke_forward_api = IsInvokeForwardApi( + self.grad_api_contents, self.forward_apis_dict + ) if next_node_generator is not None: + has_higher_order_node = True return ( + has_higher_order_node, + is_invoke_forward_api, next_grad_node_creation_str, next_grad_node_out_list, next_node_generator.backward_forward_inputs_map, ) - else: - return next_grad_node_creation_str, next_grad_node_out_list, None + elif not is_invoke_forward_api: + next_grad_node_creation_str = f""" if(trace_backward) {{ + PADDLE_THROW(phi::errors::Unavailable( + \"The Op {self.backward_api_name} doesn't have any grad\" + \"op. If you don't intend calculating higher order\" + \"derivatives, please set `create_graph`to False.\")); + }}""" + return ( + has_higher_order_node, + is_invoke_forward_api, + next_grad_node_creation_str, + next_grad_node_out_list, + None, + ) def GenerateNodeDeclaration(self): forward_op_name = self.forward_api_name @@ -1906,6 +1926,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): def GenerateNodeDefinition( self, + has_higher_order_node, + is_invoke_forward_api, next_grad_node_creation_str, next_grad_node_out_list, backward_forward_inputs_map_next, @@ -1920,9 +1942,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): backward_inplace_map = self.backward_inplace_map indent = GetIndent(1) - is_invoke_forward_api = IsInvokeForwardApi( - self.grad_api_contents, self.forward_apis_dict - ) # Construct grad_api function args # Order: TensorWrappers, GradTensors, Attributes grad_api_args_len = ( @@ -1968,7 +1987,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): is_optional = name in self.optional_inputs tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name});" if backward_inplace_map and name in backward_inplace_map.keys(): - if len(next_grad_node_creation_str) > 0: + if has_higher_order_node: if ( transformed_tensor_name in backward_forward_inputs_map_next @@ -2039,7 +2058,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): # Inplace in backward op if backward_inplace_map and name in backward_inplace_map.keys(): - if len(next_grad_node_creation_str) > 0: + if has_higher_order_node: if ( transformed_tensor_name in backward_forward_inputs_map_next @@ -2144,7 +2163,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): inplace_str = f""" if (api_output_{out_index} != nullptr && can_be_inplaced) {{ egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index}); }}""" - if len(next_grad_node_creation_str) > 0: + if has_higher_order_node: inplace_for_grad_outs_str += f""" if (trace_backward) {{ {optional_inplace_str} @@ -2236,7 +2255,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): else: assert IsVectorTensorType(rtype) - if len(next_grad_node_creation_str) > 0: + if has_higher_order_node > 0: output_autograd_meta = f""" auto& {transformed_tensor_name} = returns[{pos}]; std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name}); @@ -2327,6 +2346,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ##################### # Higher-order GradNode generation ( + has_higher_order_node, + is_invoke_forward_api, next_grad_node_creation_str, next_grad_node_out_list, backward_forward_inputs_map, @@ -2335,6 +2356,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): self.GenerateNodeDeclaration() self.GenerateNodeDefinition( + has_higher_order_node, + is_invoke_forward_api, next_grad_node_creation_str, next_grad_node_out_list, backward_forward_inputs_map,