未验证 提交 9ebe7276 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Fixed issues with intermediate kernels (#40266)

* Fix issues with intermediate kernels

* Fixed CI issues
上级 a40ea45e
...@@ -712,18 +712,24 @@ def GenerateNodeCreationCodes( ...@@ -712,18 +712,24 @@ def GenerateNodeCreationCodes(
# SetTensorWrappers # SetTensorWrappers
set_tensor_wrappers_list = [] set_tensor_wrappers_list = []
for name, (_, is_fwd_input, _) in backward_fwd_input_map.items(): for name, (atype, is_fwd_input, pos) in backward_fwd_input_map.items():
is_optional = (name in optional_inputs) is_optional = (name in optional_inputs)
if is_fwd_input: if is_fwd_input:
if is_optional: if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);" set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
else: else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else: else:
if IsVectorTensorType(atype):
tw_name = f"api_result[{pos}]"
else:
tw_name = f"api_result"
if is_optional: if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, false);" set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);"
else: else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);" set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers) set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
...@@ -1040,12 +1046,12 @@ def GenerateNodeHFile(filepath, node_declaration_str): ...@@ -1040,12 +1046,12 @@ def GenerateNodeHFile(filepath, node_declaration_str):
def GenerateForwardCCFile(filepath, forward_definition_str): def GenerateForwardCCFile(filepath, forward_definition_str):
file_contents = """ file_contents = """
#include "paddle/phi/api/lib/dygraph_api.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/phi/api/include/sparse_api.h" #include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
""" """
file_contents += GenerateCoreOpInfoDefinition() file_contents += GenerateCoreOpInfoDefinition()
......
...@@ -222,6 +222,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str): ...@@ -222,6 +222,7 @@ def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
#include "pybind11/detail/common.h" #include "pybind11/detail/common.h"
#include "paddle/phi/api/all.h" #include "paddle/phi/api/all.h"
#include "paddle/phi/api/lib/dygraph_api.h"
#include "paddle/phi/common/backend.h" #include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册