From 9ebe72767da0504cd20f5675ad459baa85e5258e Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Fri, 11 Mar 2022 09:19:07 +0800 Subject: [PATCH] Fixed issues with intermediate kernels (#40266) * Fix issues with intermediate kernels * Fixed CI issues --- .../final_state_generator/eager_gen.py | 14 ++++++++++---- .../final_state_generator/python_c_gen.py | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 4f6f437163..967891fe52 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -712,18 +712,24 @@ def GenerateNodeCreationCodes( # SetTensorWrappers set_tensor_wrappers_list = [] - for name, (_, is_fwd_input, _) in backward_fwd_input_map.items(): + for name, (atype, is_fwd_input, pos) in backward_fwd_input_map.items(): is_optional = (name in optional_inputs) + if is_fwd_input: if is_optional: set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);" else: set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" else: + if IsVectorTensorType(atype): + tw_name = f"api_result[{pos}]" + else: + tw_name = f"api_result" + if is_optional: - set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, false);" + set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);" else: - set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);" + set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);" set_tensor_wrappers_list.append(set_tensor_wrappers) set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) @@ -1040,12 +1046,12 @@ def GenerateNodeHFile(filepath, node_declaration_str): def GenerateForwardCCFile(filepath, forward_definition_str): file_contents = """ +#include "paddle/phi/api/lib/dygraph_api.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/phi/api/include/sparse_api.h" #include "paddle/fluid/eager/api/utils/global_utils.h" - """ file_contents += GenerateCoreOpInfoDefinition() diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index abf3f86bdb..eee32a2c50 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -222,6 +222,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str): #include "pybind11/detail/common.h" #include "paddle/phi/api/all.h" +#include "paddle/phi/api/lib/dygraph_api.h" #include "paddle/phi/common/backend.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/scalar.h" -- GitLab