From 25591674b5b6a83a46cb9ec5264e7e5edb851317 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Sun, 27 Mar 2022 08:39:48 +0800 Subject: [PATCH] fix inplace bug in final_state eager_gen (#40979) * fix inplace bug in final_state eager_gen * fix python_c_gen --- .../final_state_generator/eager_gen.py | 20 ++++++++++--------- .../final_state_generator/python_c_gen.py | 13 ++++++++++-- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index b87e7d5f8c1..f23582bdd15 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -807,7 +807,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional(egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name)) : {name};\n" ) else: - if inplace_map and name in inplace_map.keys(): + if is_inplaced and inplace_map and name in inplace_map.keys( + ): arg_str = f"paddle::experimental::Tensor& {name}" amp_tensors_vector_list.append(f"{{{name}}}") amp_autocast_list.append( @@ -881,7 +882,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): returns_str = ", ".join(returns_list) returns_str = f"std::make_tuple({returns_str})" - self.GenerateNodeCreationCodes(forward_call_str) + self.GenerateNodeCreationCodes(forward_call_str, is_inplaced) node_creation_str = self.node_creation_str dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);" @@ -917,7 +918,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): logging.info( f"Generated Forward Declaration: {self.forward_declaration_str}") - def GenerateNodeCreationCodes(self, forward_call_str): + def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced): forward_api_name = self.forward_api_name forward_inputs_position_map = self.forward_inputs_position_map forward_outputs_position_map = self.forward_outputs_position_map @@ -980,12 +981,13 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): # Check Inplace check_inplace_str = "" bump_inplace_version_str = "" - for inplace_name in inplace_map.keys(): - inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) - check_inplace_str += CHECK_INPLACE_TEMPLATE.format( - inplace_name, inplace_autograd_meta_name) - bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( - inplace_name, inplace_name) + if is_inplaced: + for inplace_name in inplace_map.keys(): + inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) + check_inplace_str += CHECK_INPLACE_TEMPLATE.format( + inplace_name, inplace_autograd_meta_name) + bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( + inplace_name, inplace_name) # Node Construction num_backward_inputs = len(forward_outputs_position_map.keys()) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index c7be9480f55..4b557f7f5cf 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -333,11 +333,20 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): forward_api_name, namespace, forward_api_name, forward_api_name) if len(inplace_map) > 0: - inplaced_forward_api_name = GetInplacedFunctionName( - self.forward_api_name) assert len( inplace_map ) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}" + inplaced_forward_api_name = GetInplacedFunctionName( + self.forward_api_name) + # Generate Python-C Function Definitions + if is_forward_only: + fwd_function_name = FUNCTION_NAME_TEMPLATE.format( + "paddle::experimental::", namespace, + inplaced_forward_api_name) + elif len(inplace_map) > 0: + fwd_function_name = FUNCTION_NAME_TEMPLATE.format( + "::", namespace, + GetForwardFunctionName(inplaced_forward_api_name)) for inplace_input, inplace_output in inplace_map.items(): return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format( inplaced_forward_api_name, inplace_input, -- GitLab