diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 4f6f437163a8dc86a6c62de2b75923dd022379a2..967891fe5227dcd6129c0ef1808fba7720711568 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -712,18 +712,24 @@ def GenerateNodeCreationCodes( # SetTensorWrappers set_tensor_wrappers_list = [] - for name, (_, is_fwd_input, _) in backward_fwd_input_map.items(): + for name, (atype, is_fwd_input, pos) in backward_fwd_input_map.items(): is_optional = (name in optional_inputs) + if is_fwd_input: if is_optional: set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);" else: set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" else: + if IsVectorTensorType(atype): + tw_name = f"api_result[{pos}]" + else: + tw_name = f"api_result" + if is_optional: - set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, false);" + set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);" else: - set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);" + set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);" set_tensor_wrappers_list.append(set_tensor_wrappers) set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) @@ -1040,12 +1046,12 @@ def GenerateNodeHFile(filepath, node_declaration_str): def GenerateForwardCCFile(filepath, forward_definition_str): file_contents = """ +#include "paddle/phi/api/lib/dygraph_api.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/phi/api/include/sparse_api.h" #include "paddle/fluid/eager/api/utils/global_utils.h" - """ file_contents += GenerateCoreOpInfoDefinition() diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index abf3f86bdb03b8a3bd89fbb674b7a8cfb534adf2..eee32a2c5057d523212a4faa5eca8678e961f417 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -222,6 +222,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str): #include "pybind11/detail/common.h" #include "paddle/phi/api/all.h" +#include "paddle/phi/api/lib/dygraph_api.h" #include "paddle/phi/common/backend.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/scalar.h"