diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index a079a9aedffd8bc8def05ab71bd239722fdc9641..5b7315e1ee7b1a2f67b1025bb4f47c0ea051f5d7 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -760,12 +760,10 @@ def GenerateNodeCreationCodes( # SetTensorWrappers set_tensor_wrappers_list = [] - fwd_api_input_num = 0 for name, (atype, is_fwd_input, pos) in backward_fwd_input_map.items(): is_optional = (name in optional_inputs) if is_fwd_input: - fwd_api_input_num += 1 if is_optional: set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);" else: @@ -777,9 +775,7 @@ def GenerateNodeCreationCodes( fwd_output_pos = forward_outputs_position_map[name][1] tw_name = f"std::get<{fwd_output_pos}>(api_result)" else: - assert IsPlainTensorType(atype), atype - out_pos = pos - fwd_api_input_num - tw_name = f"std::get<{out_pos}>(api_result)" + tw_name = f"api_result" if is_optional: set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);"