未验证 提交 6cd79701 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] split coed gen for eager fluid_generated (#44177)

* split coed gen for eager fluid_generated
上级 be746adf
...@@ -65,8 +65,7 @@ paddle/infrt/dialect/pd/common/pd_ops_info.h ...@@ -65,8 +65,7 @@ paddle/infrt/dialect/pd/common/pd_ops_info.h
paddle/infrt/tests/dialect/Output paddle/infrt/tests/dialect/Output
paddle/infrt/tests/lit.cfg.py paddle/infrt/tests/lit.cfg.py
paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.cc paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.cc
paddle/fluid/pybind/eager_final_state_op_function_impl.h paddle/fluid/pybind/eager_final_state_op_function.cc
paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h
# these files (directories) are generated before build system generation # these files (directories) are generated before build system generation
paddle/fluid/operators/generated_op.cc paddle/fluid/operators/generated_op.cc
......
...@@ -26,36 +26,14 @@ endif() ...@@ -26,36 +26,14 @@ endif()
message( message(
"Generate dygraph file structure at path: ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/generated" "Generate dygraph file structure at path: ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/generated"
) )
set(CODE_GEN_SPLIT_FILE_COUNT "8")
execute_process( execute_process(
COMMAND COMMAND
"${PYTHON_EXECUTABLE}" "${PYTHON_EXECUTABLE}"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generate_file_structures.py" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generate_file_structures.py"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/") "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/" "${CODE_GEN_SPLIT_FILE_COUNT}")
set(tmp_dygraph_forward_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.tmp.h"
)
set(tmp_dygraph_forward_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/forwards/dygraph_forward_functions.tmp.cc"
)
set(tmp_dygraph_node_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.tmp.h"
)
set(tmp_dygraph_node_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.tmp.cc"
)
set(dygraph_forward_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
)
set(dygraph_forward_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/forwards/dygraph_forward_functions.cc"
)
set(dygraph_node_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.h"
)
set(dygraph_node_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.cc"
)
if(WIN32) if(WIN32)
set(EAGER_CODEGEN_DEPS eager_generator) set(EAGER_CODEGEN_DEPS eager_generator)
...@@ -114,22 +92,7 @@ if(WIN32) ...@@ -114,22 +92,7 @@ if(WIN32)
COMMAND COMMAND
"${eager_generator_path}/eager_generator.exe" "${eager_generator_path}/eager_generator.exe"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_forward_h_path} "${CODE_GEN_SPLIT_FILE_COUNT}"
${dygraph_forward_h_path}
COMMENT
"copy_if_different ${tmp_dygraph_forward_h_path} to ${dygraph_forward_h_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_forward_cc_path}
${dygraph_forward_cc_path}
COMMENT
"copy_if_different ${tmp_dygraph_forward_cc_path} to ${dygraph_forward_cc_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_node_h_path}
${dygraph_node_h_path}
COMMENT
"copy_if_different ${tmp_dygraph_node_h_path} to ${dygraph_node_h_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_node_cc_path}
${dygraph_node_cc_path}
COMMENT
"copy_if_different ${tmp_dygraph_node_cc_path} to ${dygraph_node_cc_path}"
DEPENDS ${EAGER_CODEGEN_DEPS} DEPENDS ${EAGER_CODEGEN_DEPS}
VERBATIM) VERBATIM)
else() else()
...@@ -140,22 +103,7 @@ else() ...@@ -140,22 +103,7 @@ else()
"LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:${CMAKE_CURRENT_BINARY_DIR}/../../pybind" "LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:${CMAKE_CURRENT_BINARY_DIR}/../../pybind"
"${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${CMAKE_CURRENT_BINARY_DIR}/eager_generator"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_forward_h_path} "${CODE_GEN_SPLIT_FILE_COUNT}"
${dygraph_forward_h_path}
COMMENT
"copy_if_different ${tmp_dygraph_forward_h_path} to ${dygraph_forward_h_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_forward_cc_path}
${dygraph_forward_cc_path}
COMMENT
"copy_if_different ${tmp_dygraph_forward_cc_path} to ${dygraph_forward_cc_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_node_h_path}
${dygraph_node_h_path}
COMMENT
"copy_if_different ${tmp_dygraph_node_h_path} to ${dygraph_node_h_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_node_cc_path}
${dygraph_node_cc_path}
COMMENT
"copy_if_different ${tmp_dygraph_node_cc_path} to ${dygraph_node_cc_path}"
DEPENDS eager_generator DEPENDS eager_generator
VERBATIM) VERBATIM)
endif() endif()
...@@ -3108,7 +3108,8 @@ static std::string GenerateCoreOpsReturnsInfo() { ...@@ -3108,7 +3108,8 @@ static std::string GenerateCoreOpsReturnsInfo() {
return core_ops_info_str; return core_ops_info_str;
} }
static void DygraphCodeGeneration(const std::string& output_dir) { static void DygraphCodeGeneration(const std::string& output_dir,
int split_count) {
std::string dygraph_forward_api_str = GenerateDygraphHFileIncludes(); std::string dygraph_forward_api_str = GenerateDygraphHFileIncludes();
std::string fwd_function_str = ""; std::string fwd_function_str = "";
std::string grad_node_h_str = ""; std::string grad_node_h_str = "";
...@@ -3116,6 +3117,8 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -3116,6 +3117,8 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map(); auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
paddle::flat_hash_map<std::string, OpInfo> op_info_map_need_gen;
for (auto& pair : op_info_map) { for (auto& pair : op_info_map) {
const OpInfo& op_info = pair.second; const OpInfo& op_info = pair.second;
proto::OpProto* op_proto = op_info.proto_; proto::OpProto* op_proto = op_info.proto_;
...@@ -3126,6 +3129,31 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -3126,6 +3129,31 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
continue; continue;
} }
GradNodeGenerationInfo bwd_info;
bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info);
if (!is_available && !bwd_info.GenerateForwardOnly()) {
VLOG(6) << "Skipped operator: " << op_type;
continue;
}
op_info_map_need_gen.emplace(pair);
}
int each_cc_file_api_size = op_info_map_need_gen.size() / split_count;
if (op_info_map_need_gen.size() % split_count != 0) {
each_cc_file_api_size++;
}
int api_index = 0;
int file_index = 0;
for (auto& pair : op_info_map_need_gen) {
const OpInfo& op_info = pair.second;
proto::OpProto* op_proto = op_info.proto_;
const std::string& op_type = op_proto->type();
/* ----------------------------- */ /* ----------------------------- */
/* ---- Collect Information ---- */ /* ---- Collect Information ---- */
/* ----------------------------- */ /* ----------------------------- */
...@@ -3137,12 +3165,7 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -3137,12 +3165,7 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
CollectForwardInformationFromOpInfo(op_info, &fwd_info); CollectForwardInformationFromOpInfo(op_info, &fwd_info);
bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info); CollectGradInformationFromOpInfo(op_info, &bwd_info);
if (!is_available && !bwd_info.GenerateForwardOnly()) {
VLOG(6) << "Skipped operator: " << op_type;
continue;
}
VLOG(6) << "-------- PurifyOpProto -------"; VLOG(6) << "-------- PurifyOpProto -------";
PurifyForwardOpProto(*op_proto, &fwd_info); PurifyForwardOpProto(*op_proto, &fwd_info);
...@@ -3188,8 +3211,7 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -3188,8 +3211,7 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
dygraph_forward_api_str += inplace_fwd_function_declare_str; dygraph_forward_api_str += inplace_fwd_function_declare_str;
} }
if (bwd_info.GenerateForwardOnly()) continue; if (!bwd_info.GenerateForwardOnly()) {
VLOG(6) << "-------- GenerateGradNodeHeaderContents -------"; VLOG(6) << "-------- GenerateGradNodeHeaderContents -------";
grad_node_h_str += GenerateGradNodeHeaderContents(fwd_info, bwd_info); grad_node_h_str += GenerateGradNodeHeaderContents(fwd_info, bwd_info);
grad_node_h_str += "\n"; grad_node_h_str += "\n";
...@@ -3197,16 +3219,46 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -3197,16 +3219,46 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
VLOG(6) << "-------- GenerateGradNodeCCContents -------"; VLOG(6) << "-------- GenerateGradNodeCCContents -------";
grad_node_cc_str += GenerateGradNodeCCContents(fwd_info, bwd_info); grad_node_cc_str += GenerateGradNodeCCContents(fwd_info, bwd_info);
grad_node_cc_str += "\n"; grad_node_cc_str += "\n";
}
VLOG(6) << op_type << ": Finished Generating Op: " << op_type; VLOG(6) << op_type << ": Finished Generating Op: " << op_type;
}
api_index++;
if (api_index / each_cc_file_api_size > file_index) {
file_index++;
VLOG(6) << "-------- GenerateDygraphForwardCCFile -------"; VLOG(6) << "-------- GenerateDygraphForwardCCFile -------";
std::string forward_cc_path = std::string forward_cc_path = output_dir +
output_dir + "/forwards/dygraph_forward_functions.tmp.cc"; "/forwards/dygraph_forward_functions" +
std::to_string(file_index) + ".tmp.cc";
fwd_function_str += "\n"; fwd_function_str += "\n";
fwd_function_str += GenerateCoreOpsReturnsInfo();
GenerateForwardDygraphFile(forward_cc_path, fwd_function_str); GenerateForwardDygraphFile(forward_cc_path, fwd_function_str);
fwd_function_str = "";
VLOG(6) << "-------- GenerateNodeCCFile -------";
std::string node_cc_path =
output_dir + "/nodes/nodes" + std::to_string(file_index) + ".tmp.cc";
GenerateNodeCCFile(node_cc_path, grad_node_cc_str);
grad_node_cc_str = "";
}
}
file_index++;
VLOG(6) << "-------- GenerateDygraphForwardCCFile -------";
std::string forward_cc_path = output_dir +
"/forwards/dygraph_forward_functions" +
std::to_string(file_index) + ".tmp.cc";
GenerateForwardDygraphFile(forward_cc_path, fwd_function_str);
fwd_function_str = "";
GenerateForwardDygraphFile(
output_dir + "/forwards/dygraph_forward_functions_args_info.tmp.cc",
GenerateCoreOpsReturnsInfo());
VLOG(6) << "-------- GenerateNodeCCFile -------";
std::string node_cc_path =
output_dir + "/nodes/nodes" + std::to_string(file_index) + ".tmp.cc";
GenerateNodeCCFile(node_cc_path, grad_node_cc_str);
grad_node_cc_str = "";
VLOG(6) << "-------- GenerateForwardHFile -------"; VLOG(6) << "-------- GenerateForwardHFile -------";
std::string dygraph_forward_api_path = std::string dygraph_forward_api_path =
...@@ -3216,26 +3268,23 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -3216,26 +3268,23 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
VLOG(6) << "-------- GenerateNodeHFile -------"; VLOG(6) << "-------- GenerateNodeHFile -------";
std::string node_h_path = output_dir + "/nodes/nodes.tmp.h"; std::string node_h_path = output_dir + "/nodes/nodes.tmp.h";
GenerateNodeHFile(node_h_path, grad_node_h_str); GenerateNodeHFile(node_h_path, grad_node_h_str);
VLOG(6) << "-------- GenerateNodeCCFile -------";
std::string node_cc_path = output_dir + "/nodes/nodes.tmp.cc";
GenerateNodeCCFile(node_cc_path, grad_node_cc_str);
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
if (argc != 2) { if (argc != 3) {
std::cerr << "argc must be 2" << std::endl; std::cerr << "argc must be 3" << std::endl;
return -1; return -1;
} }
std::string eager_root = argv[1]; std::string eager_root = argv[1];
int split_count = atoi(argv[2]);
paddle::framework::PrepareAttrMapForOps(); paddle::framework::PrepareAttrMapForOps();
paddle::framework::DygraphCodeGeneration(eager_root); paddle::framework::DygraphCodeGeneration(eager_root, split_count);
return 0; return 0;
} }
...@@ -54,11 +54,10 @@ add_custom_target( ...@@ -54,11 +54,10 @@ add_custom_target(
VERBATIM) VERBATIM)
set(tmp_python_c_output_path set(tmp_python_c_output_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h" "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_final_state_op_function.cc.tmp"
) )
set(python_c_output_path set(python_c_output_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_final_state_op_function_impl.h" "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_final_state_op_function.cc")
)
add_custom_target( add_custom_target(
eager_final_state_python_c_codegen eager_final_state_python_c_codegen
......
...@@ -139,22 +139,16 @@ PYTHON_C_FUNCTION_REG_TEMPLATE = \ ...@@ -139,22 +139,16 @@ PYTHON_C_FUNCTION_REG_TEMPLATE = \
PYTHON_C_WRAPPER_TEMPLATE = \ PYTHON_C_WRAPPER_TEMPLATE = \
""" """
#pragma once #include <Python.h>
#include "paddle/fluid/platform/enforce.h"
#include "pybind11/detail/common.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/api/lib/dygraph_api.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/phi/api/include/strings_api.h" #include "paddle/phi/api/include/strings_api.h"
#include "paddle/fluid/pybind/op_function_common.h" #include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
#include <Python.h> #include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/eager_final_state_custom_python_api.h"
#include "paddle/fluid/pybind/eager.h"
namespace paddle {{ namespace paddle {{
namespace pybind {{ namespace pybind {{
...@@ -165,6 +159,16 @@ static PyMethodDef EagerFinalStateMethods[] = {{ ...@@ -165,6 +159,16 @@ static PyMethodDef EagerFinalStateMethods[] = {{
{} {}
}}; }};
void BindFinalStateEagerOpFunctions(pybind11::module *module) {{
if (PyModule_AddFunctions(module->ptr(), EagerFinalStateMethods) < 0) {{
PADDLE_THROW(platform::errors::Fatal ("Add functions to core.eager.ops failed!"));
}}
if (PyModule_AddFunctions(module->ptr(), CustomEagerFinalStateMethods) < 0) {{
PADDLE_THROW(platform::errors::Fatal ("Add functions to core.eager.ops failed!"));
}}
}}
}} // namespace pybind }} // namespace pybind
}} // namespace paddle }} // namespace paddle
""" """
...@@ -449,8 +453,8 @@ class PythonCGenerator(GeneratorBase): ...@@ -449,8 +453,8 @@ class PythonCGenerator(GeneratorBase):
def GeneratePythonCFunctions(self): def GeneratePythonCFunctions(self):
namespace = self.namespace namespace = self.namespace
forward_api_list = self.forward_api_list
forward_api_list = self.forward_api_list
for forward_api_content in forward_api_list: for forward_api_content in forward_api_list:
f_generator = PythonCSingleFunctionGenerator( f_generator = PythonCSingleFunctionGenerator(
forward_api_content, namespace) forward_api_content, namespace)
......
...@@ -53,7 +53,7 @@ def GenerateFileStructureForFinalDygraph(eager_dir): ...@@ -53,7 +53,7 @@ def GenerateFileStructureForFinalDygraph(eager_dir):
open(path, 'a').close() open(path, 'a').close()
def GenerateFileStructureForIntermediateDygraph(eager_dir): def GenerateFileStructureForIntermediateDygraph(eager_dir, split_count):
""" """
paddle/fluid/eager paddle/fluid/eager
|- generated |- generated
...@@ -86,11 +86,16 @@ def GenerateFileStructureForIntermediateDygraph(eager_dir): ...@@ -86,11 +86,16 @@ def GenerateFileStructureForIntermediateDygraph(eager_dir):
dygraph_forward_api_h_path = os.path.join(generated_dir, dygraph_forward_api_h_path = os.path.join(generated_dir,
"dygraph_forward_api.h") "dygraph_forward_api.h")
empty_files = [dygraph_forward_api_h_path] empty_files = [dygraph_forward_api_h_path]
empty_files.append(
os.path.join(forwards_dir, "dygraph_forward_functions.cc"))
empty_files.append(os.path.join(nodes_dir, "nodes.cc"))
empty_files.append(os.path.join(nodes_dir, "nodes.h")) empty_files.append(os.path.join(nodes_dir, "nodes.h"))
for i in range(split_count):
empty_files.append(
os.path.join(forwards_dir,
"dygraph_forward_functions" + str(i + 1) + ".cc"))
empty_files.append(os.path.join(nodes_dir,
"nodes" + str(i + 1) + ".cc"))
empty_files.append(
os.path.join(forwards_dir, "dygraph_forward_functions_args_info.cc"))
for path in empty_files: for path in empty_files:
if not os.path.exists(path): if not os.path.exists(path):
open(path, 'a').close() open(path, 'a').close()
...@@ -102,23 +107,62 @@ def GenerateFileStructureForIntermediateDygraph(eager_dir): ...@@ -102,23 +107,62 @@ def GenerateFileStructureForIntermediateDygraph(eager_dir):
forwards_level_cmakelist_path = os.path.join(forwards_dir, "CMakeLists.txt") forwards_level_cmakelist_path = os.path.join(forwards_dir, "CMakeLists.txt")
with open(nodes_level_cmakelist_path, "w") as f: with open(nodes_level_cmakelist_path, "w") as f:
f.write("add_custom_target(\n")
f.write(" copy_dygraph_node\n")
f.write( f.write(
"cc_library(dygraph_node SRCS nodes.cc DEPS ${eager_deps} ${fluid_deps} ${fluid_manual_nodes})\n" " COMMAND ${CMAKE_COMMAND} -E copy_if_different \"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.tmp.h\" \"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.h\"\n"
) )
f.write("add_dependencies(dygraph_node eager_codegen)") for i in range(split_count):
f.write(
" COMMAND ${CMAKE_COMMAND} -E copy_if_different \"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes"
+ str(i + 1) +
".tmp.cc\" \"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes"
+ str(i + 1) + ".cc\"\n")
f.write(" DEPENDS eager_codegen\n")
f.write(" VERBATIM)\n")
f.write("cc_library(dygraph_node SRCS ")
for i in range(split_count):
f.write("nodes" + str(i + 1) + ".cc ")
f.write("DEPS ${eager_deps} ${fluid_deps} ${fluid_manual_nodes})\n")
f.write("add_dependencies(dygraph_node copy_dygraph_node)")
with open(forwards_level_cmakelist_path, "w") as f: with open(forwards_level_cmakelist_path, "w") as f:
f.write("add_custom_target(\n")
f.write(" copy_dygraph_forward_functions\n")
f.write( f.write(
"cc_library(dygraph_function SRCS dygraph_forward_functions.cc DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ${fluid_manual_functions})\n" " COMMAND ${CMAKE_COMMAND} -E copy_if_different \"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.tmp.h\" \"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h\"\n"
) )
f.write("add_dependencies(dygraph_function eager_codegen)") for i in range(split_count):
f.write(
" COMMAND ${CMAKE_COMMAND} -E copy_if_different \"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/forwards/dygraph_forward_functions"
+ str(i + 1) +
".tmp.cc\" \"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/forwards/dygraph_forward_functions"
+ str(i + 1) + ".cc\"\n")
f.write(
" COMMAND ${CMAKE_COMMAND} -E copy_if_different \"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/forwards/dygraph_forward_functions_args_info.tmp.cc\" \"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/forwards/dygraph_forward_functions_args_info.cc\"\n"
)
f.write(" DEPENDS eager_codegen\n")
f.write(" VERBATIM)\n")
f.write("cc_library(dygraph_function SRCS ")
for i in range(split_count):
f.write("dygraph_forward_functions" + str(i + 1) + ".cc ")
f.write("dygraph_forward_functions_args_info.cc ")
f.write(
"DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ${fluid_manual_functions})\n"
)
f.write(
"add_dependencies(dygraph_function copy_dygraph_forward_functions)")
with open(generated_level_cmakelist_path, "w") as f: with open(generated_level_cmakelist_path, "w") as f:
f.write("add_subdirectory(forwards)\nadd_subdirectory(nodes)") f.write("add_subdirectory(forwards)\nadd_subdirectory(nodes)")
if __name__ == "__main__": if __name__ == "__main__":
assert len(sys.argv) == 2 assert len(sys.argv) == 3
eager_dir = sys.argv[1] eager_dir = sys.argv[1]
GenerateFileStructureForIntermediateDygraph(eager_dir) split_count = int(sys.argv[2])
GenerateFileStructureForIntermediateDygraph(eager_dir, split_count)
GenerateFileStructureForFinalDygraph(eager_dir) GenerateFileStructureForFinalDygraph(eager_dir)
pybind.h pybind.h
op_function_impl.h op_function.cc
eager_op_function_impl.h eager_op_function.cc
eager_final_state_op_function_impl.h eager_final_state_op_function.cc
tmp_eager_final_state_op_function_impl.h
...@@ -101,11 +101,16 @@ endif() ...@@ -101,11 +101,16 @@ endif()
set(PYBIND_SRCS set(PYBIND_SRCS
pybind.cc pybind.cc
exception.cc imperative.cc
op_function.cc
inference_api.cc
ir.cc
bind_fleet_executor.cc
reader_py.cc
protobuf.cc protobuf.cc
exception.cc
const_value.cc const_value.cc
global_value_getter_setter.cc global_value_getter_setter.cc
reader_py.cc
fleet_wrapper_py.cc fleet_wrapper_py.cc
heter_wrapper_py.cc heter_wrapper_py.cc
ps_gpu_wrapper_py.cc ps_gpu_wrapper_py.cc
...@@ -113,11 +118,7 @@ set(PYBIND_SRCS ...@@ -113,11 +118,7 @@ set(PYBIND_SRCS
box_helper_py.cc box_helper_py.cc
metrics_py.cc metrics_py.cc
data_set_py.cc data_set_py.cc
imperative.cc
ir.cc
bind_cost_model.cc bind_cost_model.cc
bind_fleet_executor.cc
inference_api.cc
compatible.cc compatible.cc
io.cc io.cc
generator_py.cc generator_py.cc
...@@ -125,6 +126,12 @@ set(PYBIND_SRCS ...@@ -125,6 +126,12 @@ set(PYBIND_SRCS
cuda_streams_py.cc cuda_streams_py.cc
jit.cc) jit.cc)
execute_process(
COMMAND
"${PYTHON_EXECUTABLE}"
"${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/generate_file_structures.py"
"${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/")
if(WITH_CUSTOM_DEVICE) if(WITH_CUSTOM_DEVICE)
set(PYBIND_DEPS ${PYBIND_DEPS} phi_capi) set(PYBIND_DEPS ${PYBIND_DEPS} phi_capi)
endif() endif()
...@@ -189,7 +196,8 @@ if(WITH_PSCORE) ...@@ -189,7 +196,8 @@ if(WITH_PSCORE)
set_source_files_properties( set_source_files_properties(
fleet_py.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) fleet_py.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
list(APPEND PYBIND_DEPS fleet communicator index_wrapper index_sampler) list(APPEND PYBIND_DEPS fleet communicator index_wrapper index_sampler)
list(APPEND PYBIND_SRCS fleet_py.cc) list(APPEND PYBIND_SRCS)
set(PYBIND_SRCS fleet_py.cc ${PYBIND_SRCS})
endif() endif()
if(WITH_NCCL OR WITH_RCCL) if(WITH_NCCL OR WITH_RCCL)
...@@ -259,10 +267,10 @@ if(WITH_PYTHON) ...@@ -259,10 +267,10 @@ if(WITH_PYTHON)
target_link_libraries(kernel_signature_generator ${ROCM_HIPRTC_LIB}) target_link_libraries(kernel_signature_generator ${ROCM_HIPRTC_LIB})
endif() endif()
set(impl_file ${CMAKE_SOURCE_DIR}/paddle/fluid/pybind/op_function_impl.h) set(impl_file ${CMAKE_SOURCE_DIR}/paddle/fluid/pybind/op_function.cc)
set(tmp_impl_file ${impl_file}.tmp) set(tmp_impl_file ${impl_file}.tmp)
set(eager_impl_file set(eager_impl_file
${CMAKE_SOURCE_DIR}/paddle/fluid/pybind/eager_op_function_impl.h) ${CMAKE_SOURCE_DIR}/paddle/fluid/pybind/eager_op_function.cc)
set(tmp_eager_impl_file ${eager_impl_file}.tmp) set(tmp_eager_impl_file ${eager_impl_file}.tmp)
set(OP_IMPL_DEPS op_function_generator) set(OP_IMPL_DEPS op_function_generator)
...@@ -461,30 +469,31 @@ if(WITH_PYTHON) ...@@ -461,30 +469,31 @@ if(WITH_PYTHON)
list(APPEND PYBIND_DEPS op_function_common) list(APPEND PYBIND_DEPS op_function_common)
if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
cc_library( set(PYBIND_SRCS eager.cc ${PYBIND_SRCS})
paddle_eager set(PYBIND_SRCS eager_functions.cc ${PYBIND_SRCS})
SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc set(PYBIND_SRCS eager_method.cc ${PYBIND_SRCS})
eager_utils.cc eager_py_layer.cc set(PYBIND_SRCS eager_properties.cc ${PYBIND_SRCS})
DEPS eager_api set(PYBIND_SRCS eager_utils.cc ${PYBIND_SRCS})
autograd_meta set(PYBIND_SRCS eager_py_layer.cc ${PYBIND_SRCS})
backward set(PYBIND_SRCS eager_op_function.cc ${PYBIND_SRCS})
grad_node_info set(PYBIND_SRCS eager_final_state_op_function.cc ${PYBIND_SRCS})
phi list(APPEND PYBIND_DEPS eager_api)
op_function_common list(APPEND PYBIND_DEPS autograd_meta)
final_dygraph_function list(APPEND PYBIND_DEPS backward)
final_dygraph_node list(APPEND PYBIND_DEPS grad_node_info)
dygraph_function list(APPEND PYBIND_DEPS phi)
dygraph_node list(APPEND PYBIND_DEPS op_function_common)
accumulation_node list(APPEND PYBIND_DEPS final_dygraph_function)
py_layer_node list(APPEND PYBIND_DEPS final_dygraph_node)
global_utils list(APPEND PYBIND_DEPS dygraph_function)
utils list(APPEND PYBIND_DEPS dygraph_node)
python list(APPEND PYBIND_DEPS accumulation_node)
custom_operator list(APPEND PYBIND_DEPS py_layer_node)
custom_operator_node) list(APPEND PYBIND_DEPS global_utils)
add_dependencies(paddle_eager eager_codegen) list(APPEND PYBIND_DEPS utils)
add_dependencies(paddle_eager eager_op_function_generator_cmd) list(APPEND PYBIND_DEPS python)
list(APPEND PYBIND_DEPS paddle_eager) list(APPEND PYBIND_DEPS custom_operator)
list(APPEND PYBIND_DEPS custom_operator_node)
endif() endif()
cc_library( cc_library(
...@@ -492,6 +501,11 @@ if(WITH_PYTHON) ...@@ -492,6 +501,11 @@ if(WITH_PYTHON)
SRCS ${PYBIND_SRCS} SRCS ${PYBIND_SRCS}
DEPS ${PYBIND_DEPS} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS}) DEPS ${PYBIND_DEPS} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_dependencies(paddle_pybind eager_codegen)
add_dependencies(paddle_pybind eager_op_function_generator_cmd)
endif()
if(NOT APPLE AND NOT WIN32) if(NOT APPLE AND NOT WIN32)
target_link_libraries(paddle_pybind rt) target_link_libraries(paddle_pybind rt)
endif() endif()
......
...@@ -33,7 +33,7 @@ limitations under the License. */ ...@@ -33,7 +33,7 @@ limitations under the License. */
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers" #pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include "paddle/fluid/framework/python_headers.h" #include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/pybind/eager_op_function_impl.h" #include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h" #include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/core/string_tensor.h" #include "paddle/phi/core/string_tensor.h"
......
...@@ -40,6 +40,7 @@ void BindEager(pybind11::module* m); ...@@ -40,6 +40,7 @@ void BindEager(pybind11::module* m);
void BindEagerStringTensor(pybind11::module* module); void BindEagerStringTensor(pybind11::module* module);
void BindFunctions(PyObject* module); void BindFunctions(PyObject* module);
void BindEagerPyLayer(PyObject* module); void BindEagerPyLayer(PyObject* module);
void BindEagerOpFunctions(pybind11::module* module);
void BindFinalStateEagerOpFunctions(pybind11::module* module);
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
...@@ -15,8 +15,12 @@ ...@@ -15,8 +15,12 @@
#include <iostream> #include <iostream>
#include "paddle/fluid/eager/to_static/run_program_op_func.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
namespace paddle {
namespace pybind {
static PyObject *eager_api_run_program(PyObject *self, static PyObject *eager_api_run_program(PyObject *self,
PyObject *args, PyObject *args,
PyObject *kwargs) { PyObject *kwargs) {
...@@ -57,55 +61,12 @@ static PyObject *eager_api_run_program(PyObject *self, ...@@ -57,55 +61,12 @@ static PyObject *eager_api_run_program(PyObject *self,
} }
} }
static PyObject *eager_api_final_state_linear(PyObject *self, static PyMethodDef CustomEagerMethods[] = {
PyObject *args,
PyObject *kwargs) {
PyThreadState *tstate = nullptr;
try {
auto x = GetTensorFromArgs("linear", "X", args, 0, false);
auto weight = GetTensorFromArgs("linear", "weight", args, 1, false);
auto bias = GetTensorFromArgs("linear", "Bias", args, 2, true);
tstate = PyEval_SaveThread();
if (bias.initialized()) {
auto mm_out =
matmul_final_state_dygraph_function(x, weight, false, false);
auto out = add_final_state_dygraph_function(mm_out, bias);
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(out);
} else {
auto mm_out =
matmul_final_state_dygraph_function(x, weight, false, false);
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(mm_out);
}
} catch (paddle::platform::EnforceNotMet &exception) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
std::ostringstream sout;
sout << exception.what();
sout << " [operator < linear > error]";
exception.set_error_str(sout.str());
ThrowExceptionToPython(std::current_exception());
return nullptr;
} catch (...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
static PyMethodDef CustomEagerFinalStateMethods[] = {
{"run_program", {"run_program",
(PyCFunction)(void (*)(void))eager_api_run_program, (PyCFunction)(void (*)(void))eager_api_run_program,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
"C++ interface function for run_program in dygraph."}, "C++ interface function for run_program in dygraph."},
{"final_state_linear",
(PyCFunction)(void (*)(void))eager_api_final_state_linear,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for run_program in dygraph."},
{nullptr, nullptr, 0, nullptr}}; {nullptr, nullptr, 0, nullptr}};
} // namespace pybind
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace pybind {
static PyObject *eager_api_final_state_linear(PyObject *self,
PyObject *args,
PyObject *kwargs) {
PyThreadState *tstate = nullptr;
try {
auto x = GetTensorFromArgs("linear", "X", args, 0, false);
auto weight = GetTensorFromArgs("linear", "weight", args, 1, false);
auto bias = GetTensorFromArgs("linear", "Bias", args, 2, true);
tstate = PyEval_SaveThread();
if (bias.initialized()) {
auto mm_out =
matmul_final_state_dygraph_function(x, weight, false, false);
auto out = add_final_state_dygraph_function(mm_out, bias);
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(out);
} else {
auto mm_out =
matmul_final_state_dygraph_function(x, weight, false, false);
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(mm_out);
}
} catch (paddle::platform::EnforceNotMet &exception) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
std::ostringstream sout;
sout << exception.what();
sout << " [operator < linear > error]";
exception.set_error_str(sout.str());
ThrowExceptionToPython(std::current_exception());
return nullptr;
} catch (...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
static PyMethodDef CustomEagerFinalStateMethods[] = {
{"final_state_linear",
(PyCFunction)(void (*)(void))eager_api_final_state_linear,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for run_program in dygraph."},
{nullptr, nullptr, 0, nullptr}};
} // namespace pybind
} // namespace paddle
...@@ -138,8 +138,6 @@ const char* PYBIND_ITEM_TEMPLATE = R"( {"%s", (PyCFunction)(void(*)(void))%s, M ...@@ -138,8 +138,6 @@ const char* PYBIND_ITEM_TEMPLATE = R"( {"%s", (PyCFunction)(void(*)(void))%s, M
// These operators will skip automatical code generatrion and // These operators will skip automatical code generatrion and
// need to be handwritten in CUSTOM_HANDWRITE_OP_FUNC_FILE // need to be handwritten in CUSTOM_HANDWRITE_OP_FUNC_FILE
std::unordered_set<std::string> CUSTOM_HANDWRITE_OPS_SET = {"run_program"}; std::unordered_set<std::string> CUSTOM_HANDWRITE_OPS_SET = {"run_program"};
const char* CUSTOM_HANDWRITE_OP_FUNC_FILE =
"#include \"paddle/fluid/pybind/eager_custom_python_api.h\"\n";
// clang-format on // clang-format on
static inline bool FindInsMap(const std::string& op_type, static inline bool FindInsMap(const std::string& op_type,
...@@ -413,7 +411,6 @@ GenerateOpFunctions() { ...@@ -413,7 +411,6 @@ GenerateOpFunctions() {
std::vector<std::string> op_function_list, bind_function_list; std::vector<std::string> op_function_list, bind_function_list;
auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels(); auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();
bool append_custom_head_file = false;
for (auto& pair : op_info_map) { for (auto& pair : op_info_map) {
auto& op_info = pair.second; auto& op_info = pair.second;
auto op_proto = op_info.proto_; auto op_proto = op_info.proto_;
...@@ -423,7 +420,6 @@ GenerateOpFunctions() { ...@@ -423,7 +420,6 @@ GenerateOpFunctions() {
auto& op_type = op_proto->type(); auto& op_type = op_proto->type();
// Skip operators that will be handwriten in CUSTOM_HANDWRITE_OP_FUNC_FILE. // Skip operators that will be handwriten in CUSTOM_HANDWRITE_OP_FUNC_FILE.
if (CUSTOM_HANDWRITE_OPS_SET.count(op_type)) { if (CUSTOM_HANDWRITE_OPS_SET.count(op_type)) {
append_custom_head_file = true;
continue; continue;
} }
// Skip operator which is not inherit form OperatorWithKernel, like while, // Skip operator which is not inherit form OperatorWithKernel, like while,
...@@ -480,9 +476,7 @@ GenerateOpFunctions() { ...@@ -480,9 +476,7 @@ GenerateOpFunctions() {
bind_function_list.emplace_back(std::move(inplace_bind_function_str)); bind_function_list.emplace_back(std::move(inplace_bind_function_str));
} }
} }
if (append_custom_head_file) {
op_function_list.emplace_back(CUSTOM_HANDWRITE_OP_FUNC_FILE);
}
return std::make_tuple(op_function_list, bind_function_list); return std::make_tuple(op_function_list, bind_function_list);
} }
...@@ -498,18 +492,19 @@ int main(int argc, char* argv[]) { ...@@ -498,18 +492,19 @@ int main(int argc, char* argv[]) {
#endif #endif
std::vector<std::string> headers{ std::vector<std::string> headers{
"\"pybind11/detail/common.h\"", "<Python.h>",
"\"paddle/fluid/pybind/eager_final_state_op_function_impl.h\"", "\"paddle/fluid/platform/enforce.h\"",
"\"paddle/fluid/pybind/op_function_common.h\"",
"\"paddle/fluid/eager/api/generated/fluid_generated/" "\"paddle/fluid/eager/api/generated/fluid_generated/"
"dygraph_forward_api.h\"", "dygraph_forward_api.h\"",
"\"paddle/fluid/pybind/eager_utils.h\"",
"\"paddle/fluid/platform/profiler/event_tracing.h\"",
"\"paddle/fluid/pybind/exception.h\"", "\"paddle/fluid/pybind/exception.h\"",
"<Python.h>"}; "\"paddle/fluid/pybind/op_function_common.h\"",
"\"paddle/fluid/pybind/eager_custom_python_api.h\"",
"\"paddle/fluid/pybind/eager.h\""};
std::ofstream out(argv[1], std::ios::out); std::ofstream out(argv[1], std::ios::out);
out << "#pragma once\n\n";
for (auto& header : headers) { for (auto& header : headers) {
out << "#include " + header + "\n"; out << "#include " + header + "\n";
} }
...@@ -542,22 +537,20 @@ int main(int argc, char* argv[]) { ...@@ -542,22 +537,20 @@ int main(int argc, char* argv[]) {
<< core_ops_infos_registry << "\n {nullptr,nullptr,0,nullptr}" << core_ops_infos_registry << "\n {nullptr,nullptr,0,nullptr}"
<< "};\n\n"; << "};\n\n";
out << "inline void BindEagerOpFunctions(pybind11::module *module) {\n" out << "void BindEagerOpFunctions(pybind11::module *module) {\n"
<< " InitOpsAttrTypeMap();\n" << " InitOpsAttrTypeMap();\n"
<< " auto m = module->def_submodule(\"ops\");\n" << " auto m = module->def_submodule(\"ops\");\n"
<< " if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n" << " if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n"
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to " << " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.eager.ops failed!\"));\n" "core.eager.ops failed!\"));\n"
<< " }\n\n" << " }\n\n"
<< " if (PyModule_AddFunctions(m.ptr(), EagerFinalStateMethods) < 0) {\n" << " if (PyModule_AddFunctions(m.ptr(), CustomEagerMethods) < "
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.eager.ops failed!\"));\n"
<< " }\n\n"
<< " if (PyModule_AddFunctions(m.ptr(), CustomEagerFinalStateMethods) < "
"0) {\n" "0) {\n"
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to " << " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.eager.ops failed!\"));\n" "core.eager.ops failed!\"));\n"
<< " }\n\n" << " }\n\n"
<< " BindFinalStateEagerOpFunctions(&m);\n\n"
<< "}\n\n" << "}\n\n"
<< "} // namespace pybind\n" << "} // namespace pybind\n"
<< "} // namespace paddle\n"; << "} // namespace paddle\n";
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
if __name__ == "__main__":
assert len(sys.argv) == 2
pybind_dir = sys.argv[1]
empty_files = [os.path.join(pybind_dir, "eager_final_state_op_function.cc")]
empty_files.append(os.path.join(pybind_dir, "eager_op_function.cc"))
empty_files.append(os.path.join(pybind_dir, "op_function.cc"))
for path in empty_files:
if not os.path.exists(path):
open(path, 'a').close()
...@@ -257,8 +257,7 @@ PyObject* MakeReturnPyObject(const std::tuple<Args...>& out) { ...@@ -257,8 +257,7 @@ PyObject* MakeReturnPyObject(const std::tuple<Args...>& out) {
return result; return result;
} }
void BindOpFunctions(pybind11::module* module);
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
// This include must be the last line
#include "paddle/fluid/pybind/op_function_impl.h"
...@@ -506,13 +506,15 @@ int main(int argc, char* argv[]) { ...@@ -506,13 +506,15 @@ int main(int argc, char* argv[]) {
std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\"", std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\"",
"\"paddle/fluid/platform/profiler.h\"", "\"paddle/fluid/platform/profiler.h\"",
"\"pybind11/numpy.h\"",
"\"pybind11/pybind11.h\"",
"\"pybind11/detail/common.h\"", "\"pybind11/detail/common.h\"",
"\"paddle/fluid/pybind/eager_utils.h\"",
"\"paddle/fluid/pybind/op_function.h\"",
"<Python.h>"}; "<Python.h>"};
std::ofstream out(argv[1], std::ios::out); std::ofstream out(argv[1], std::ios::out);
out << "#pragma once\n\n";
for (auto& header : headers) { for (auto& header : headers) {
out << "#include " + header + "\n"; out << "#include " + header + "\n";
} }
...@@ -532,7 +534,7 @@ int main(int argc, char* argv[]) { ...@@ -532,7 +534,7 @@ int main(int argc, char* argv[]) {
<< "\n {nullptr,nullptr,0,nullptr}" << "\n {nullptr,nullptr,0,nullptr}"
<< "};\n\n"; << "};\n\n";
out << "inline void BindOpFunctions(pybind11::module *module) {\n" out << "void BindOpFunctions(pybind11::module *module) {\n"
<< " auto m = module->def_submodule(\"ops\");\n" << " auto m = module->def_submodule(\"ops\");\n"
<< " if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n" << " if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n"
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to " << " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册