未验证 提交 25591674 编写于 作者: P pangyoki 提交者: GitHub

fix inplace bug in final_state eager_gen (#40979)

* fix inplace bug in final_state eager_gen

* fix python_c_gen
上级 52f07ab4
...@@ -807,7 +807,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -807,7 +807,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name)) : {name};\n" f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name)) : {name};\n"
) )
else: else:
if inplace_map and name in inplace_map.keys(): if is_inplaced and inplace_map and name in inplace_map.keys(
):
arg_str = f"paddle::experimental::Tensor& {name}" arg_str = f"paddle::experimental::Tensor& {name}"
amp_tensors_vector_list.append(f"{{{name}}}") amp_tensors_vector_list.append(f"{{{name}}}")
amp_autocast_list.append( amp_autocast_list.append(
...@@ -881,7 +882,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -881,7 +882,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
returns_str = ", ".join(returns_list) returns_str = ", ".join(returns_list)
returns_str = f"std::make_tuple({returns_str})" returns_str = f"std::make_tuple({returns_str})"
self.GenerateNodeCreationCodes(forward_call_str) self.GenerateNodeCreationCodes(forward_call_str, is_inplaced)
node_creation_str = self.node_creation_str node_creation_str = self.node_creation_str
dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);" dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
...@@ -917,7 +918,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -917,7 +918,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
logging.info( logging.info(
f"Generated Forward Declaration: {self.forward_declaration_str}") f"Generated Forward Declaration: {self.forward_declaration_str}")
def GenerateNodeCreationCodes(self, forward_call_str): def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
forward_api_name = self.forward_api_name forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map forward_outputs_position_map = self.forward_outputs_position_map
...@@ -980,12 +981,13 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -980,12 +981,13 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
# Check Inplace # Check Inplace
check_inplace_str = "" check_inplace_str = ""
bump_inplace_version_str = "" bump_inplace_version_str = ""
for inplace_name in inplace_map.keys(): if is_inplaced:
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) for inplace_name in inplace_map.keys():
check_inplace_str += CHECK_INPLACE_TEMPLATE.format( inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
inplace_name, inplace_autograd_meta_name) check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( inplace_name, inplace_autograd_meta_name)
inplace_name, inplace_name) bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)
# Node Construction # Node Construction
num_backward_inputs = len(forward_outputs_position_map.keys()) num_backward_inputs = len(forward_outputs_position_map.keys())
......
...@@ -333,11 +333,20 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -333,11 +333,20 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
forward_api_name, namespace, forward_api_name, forward_api_name) forward_api_name, namespace, forward_api_name, forward_api_name)
if len(inplace_map) > 0: if len(inplace_map) > 0:
inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name)
assert len( assert len(
inplace_map inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}" ) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name)
# Generate Python-C Function Definitions
if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace,
inplaced_forward_api_name)
elif len(inplace_map) > 0:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace,
GetForwardFunctionName(inplaced_forward_api_name))
for inplace_input, inplace_output in inplace_map.items(): for inplace_input, inplace_output in inplace_map.items():
return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format( return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
inplaced_forward_api_name, inplace_input, inplaced_forward_api_name, inplace_input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册