未验证 提交 25591674 编写于 作者: P pangyoki 提交者: GitHub

fix inplace bug in final_state eager_gen (#40979)

* fix inplace bug in final_state eager_gen

* fix python_c_gen
上级 52f07ab4
......@@ -807,7 +807,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name)) : {name};\n"
)
else:
if inplace_map and name in inplace_map.keys():
if is_inplaced and inplace_map and name in inplace_map.keys(
):
arg_str = f"paddle::experimental::Tensor& {name}"
amp_tensors_vector_list.append(f"{{{name}}}")
amp_autocast_list.append(
......@@ -881,7 +882,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
returns_str = ", ".join(returns_list)
returns_str = f"std::make_tuple({returns_str})"
self.GenerateNodeCreationCodes(forward_call_str)
self.GenerateNodeCreationCodes(forward_call_str, is_inplaced)
node_creation_str = self.node_creation_str
dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
......@@ -917,7 +918,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
logging.info(
f"Generated Forward Declaration: {self.forward_declaration_str}")
def GenerateNodeCreationCodes(self, forward_call_str):
def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
......@@ -980,12 +981,13 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
# Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)
if is_inplaced:
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)
# Node Construction
num_backward_inputs = len(forward_outputs_position_map.keys())
......
......@@ -333,11 +333,20 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
forward_api_name, namespace, forward_api_name, forward_api_name)
if len(inplace_map) > 0:
inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name)
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name)
# Generate Python-C Function Definitions
if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace,
inplaced_forward_api_name)
elif len(inplace_map) > 0:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace,
GetForwardFunctionName(inplaced_forward_api_name))
for inplace_input, inplace_output in inplace_map.items():
return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
inplaced_forward_api_name, inplace_input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册