From c43f8afbd94471166ad6dccdbef87f6cd1bfbe28 Mon Sep 17 00:00:00 2001 From: phlrain Date: Sat, 19 Mar 2022 15:07:48 +0000 Subject: [PATCH] fix generator bug; --- .../auto_code_generator/final_state_generator/eager_gen.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index a079a9aedf..5b7315e1ee 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -760,12 +760,10 @@ def GenerateNodeCreationCodes( # SetTensorWrappers set_tensor_wrappers_list = [] - fwd_api_input_num = 0 for name, (atype, is_fwd_input, pos) in backward_fwd_input_map.items(): is_optional = (name in optional_inputs) if is_fwd_input: - fwd_api_input_num += 1 if is_optional: set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);" else: @@ -777,9 +775,7 @@ def GenerateNodeCreationCodes( fwd_output_pos = forward_outputs_position_map[name][1] tw_name = f"std::get<{fwd_output_pos}>(api_result)" else: - assert IsPlainTensorType(atype), atype - out_pos = pos - fwd_api_input_num - tw_name = f"std::get<{out_pos}>(api_result)" + tw_name = f"api_result" if is_optional: set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);" -- GitLab