未验证 提交 2fc429f1 编写于 作者: Z zyfncg 提交者: GitHub

[CINN] Remove some pybind interface in cinn to fix compile problem (#55043)

* remove some pybind interface in cinn to fix compile problem

* modify cmake

* fix cmake

* add log for build cinn whl

* fix ninja for cinn

* fix conflict
上级 31edad21
...@@ -24,7 +24,7 @@ set(SOURCE_INCLUDE_DIR ${SOURCE_DIR}/include) ...@@ -24,7 +24,7 @@ set(SOURCE_INCLUDE_DIR ${SOURCE_DIR}/include)
include_directories(${PYBIND_INCLUDE_DIR}) include_directories(${PYBIND_INCLUDE_DIR})
set(PYBIND_PATCH_COMMAND "") set(PYBIND_PATCH_COMMAND "")
if(NOT WIN32 AND NOT CINN_ONLY) if(NOT WIN32)
file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/pybind/cast.h.patch file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/pybind/cast.h.patch
native_dst) native_dst)
# Note: [Why calling some `git` commands before `patch`?] # Note: [Why calling some `git` commands before `patch`?]
......
...@@ -12,9 +12,7 @@ add_subdirectory(backends) ...@@ -12,9 +12,7 @@ add_subdirectory(backends)
add_subdirectory(lang) add_subdirectory(lang)
add_subdirectory(optim) add_subdirectory(optim)
add_subdirectory(hlir) add_subdirectory(hlir)
if(CINN_ONLY) add_subdirectory(pybind)
add_subdirectory(pybind)
endif()
add_subdirectory(frontend) add_subdirectory(frontend)
# Download a model # Download a model
......
...@@ -179,23 +179,8 @@ void DefineUnaryOpNode(py::module *m, absl::string_view node_name) { ...@@ -179,23 +179,8 @@ void DefineUnaryOpNode(py::module *m, absl::string_view node_name) {
py::return_value_policy::reference); py::return_value_policy::reference);
} }
class ObjectWrapper : public Object {
public:
using Object::Object;
const char *type_info() const override {
PYBIND11_OVERLOAD_PURE(const char *, Object, type_info);
}
};
class IrNodeWrapper : ir::IrNode { class IrNodeWrapper : ir::IrNode {
using ir::IrNode::IrNode; using ir::IrNode::IrNode;
}; };
class _Operation_Wrapper : ir::_Operation_ {
public:
const char *func_type() const override {
PYBIND11_OVERLOAD_PURE(const char *, ir::_Operation_, func_type);
}
};
} // namespace cinn::pybind } // namespace cinn::pybind
...@@ -30,7 +30,6 @@ namespace cinn::pybind { ...@@ -30,7 +30,6 @@ namespace cinn::pybind {
using common::bfloat16; using common::bfloat16;
using common::CINNValue; using common::CINNValue;
using common::float16; using common::float16;
using common::Object;
using common::Target; using common::Target;
using common::Type; using common::Type;
using utils::GetStreamCnt; using utils::GetStreamCnt;
...@@ -39,7 +38,6 @@ using utils::StringFormat; ...@@ -39,7 +38,6 @@ using utils::StringFormat;
namespace { namespace {
void BindTarget(py::module *); void BindTarget(py::module *);
void BindType(py::module *); void BindType(py::module *);
void BindObject(py::module *);
void BindShared(py::module *); void BindShared(py::module *);
void BindCinnValue(py::module *); void BindCinnValue(py::module *);
...@@ -208,11 +206,6 @@ void BindType(py::module *m) { ...@@ -208,11 +206,6 @@ void BindType(py::module *m) {
}); });
} }
void BindObject(py::module *m) {
py::class_<Object, ObjectWrapper> object(*m, "Object");
object.def("type_info", &Object::type_info);
}
void BindShared(py::module *m) { void BindShared(py::module *m) {
py::class_<common::RefCount> ref_count(*m, "RefCount"); py::class_<common::RefCount> ref_count(*m, "RefCount");
ref_count.def(py::init<>()) ref_count.def(py::init<>())
...@@ -367,7 +360,6 @@ void BindCinnValue(py::module *m) { ...@@ -367,7 +360,6 @@ void BindCinnValue(py::module *m) {
void BindCommon(py::module *m) { void BindCommon(py::module *m) {
BindTarget(m); BindTarget(m);
BindType(m); BindType(m);
BindObject(m);
BindShared(m); BindShared(m);
BindCinnValue(m); BindCinnValue(m);
} }
......
...@@ -108,7 +108,7 @@ void BindNode(py::module *m) { ...@@ -108,7 +108,7 @@ void BindNode(py::module *m) {
#undef DECLARE_IR_NODE_TY #undef DECLARE_IR_NODE_TY
// class IrNode // class IrNode
py::class_<ir::IrNode, IrNodeWrapper /*, ObjectWrapper*/> ir_node( py::class_<ir::IrNode, IrNodeWrapper> ir_node(
*m, "IrNode", py::module_local()); *m, "IrNode", py::module_local());
ir_node.def(py::init<>()) ir_node.def(py::init<>())
.def(py::init<ir::Type>()) .def(py::init<ir::Type>())
...@@ -539,14 +539,13 @@ void BindIrIr(py::module *m) { ...@@ -539,14 +539,13 @@ void BindIrIr(py::module *m) {
} }
void BindOperation(py::module *m) { void BindOperation(py::module *m) {
py::class_<ir::PlaceholderOp /*, _Operation_Wrapper*/> placeholder_op( py::class_<ir::PlaceholderOp> placeholder_op(*m, "PlaceholderOp");
*m, "PlaceholderOp");
placeholder_op.def_readwrite("shape", &ir::PlaceholderOp::shape) placeholder_op.def_readwrite("shape", &ir::PlaceholderOp::shape)
.def_readwrite("dtype", &ir::PlaceholderOp::dtype) .def_readwrite("dtype", &ir::PlaceholderOp::dtype)
.def_static("make", &ir::PlaceholderOp::Make) .def_static("make", &ir::PlaceholderOp::Make)
.def("func_type", &ir::PlaceholderOp::func_type); .def("func_type", &ir::PlaceholderOp::func_type);
py::class_<ir::CallOp /*, _Operation_Wrapper*/> call_op(*m, "CallOp"); py::class_<ir::CallOp> call_op(*m, "CallOp");
call_op.def("target", &ir::CallOp::target) call_op.def("target", &ir::CallOp::target)
.def_readwrite("call_expr", &ir::CallOp::call_expr) .def_readwrite("call_expr", &ir::CallOp::call_expr)
.def("read_args_mutable", py::overload_cast<>(&ir::CallOp::read_args)) .def("read_args_mutable", py::overload_cast<>(&ir::CallOp::read_args))
...@@ -564,8 +563,7 @@ void BindOperation(py::module *m) { ...@@ -564,8 +563,7 @@ void BindOperation(py::module *m) {
.def_static("make", &ir::CallOp::Make) .def_static("make", &ir::CallOp::Make)
.def("func_type", &ir::CallOp::func_type); .def("func_type", &ir::CallOp::func_type);
py::class_<ir::ComputeOp /*, _Operation_Wrapper*/> compute_op(*m, py::class_<ir::ComputeOp> compute_op(*m, "ComputeOp");
"ComputeOp");
compute_op.def_readwrite("reduce_axis", &ir::ComputeOp::reduce_axis) compute_op.def_readwrite("reduce_axis", &ir::ComputeOp::reduce_axis)
.def_readwrite("shape", &ir::ComputeOp::shape) .def_readwrite("shape", &ir::ComputeOp::shape)
.def_readwrite("body", &ir::ComputeOp::body) .def_readwrite("body", &ir::ComputeOp::body)
......
...@@ -69,7 +69,7 @@ void BindStageMap(py::module *m) { ...@@ -69,7 +69,7 @@ void BindStageMap(py::module *m) {
} }
void BindStage(py::module *m) { void BindStage(py::module *m) {
py::class_<Stage, common::Object> stage(*m, "Stage"); py::class_<Stage> stage(*m, "Stage");
// enum Stage::ComputeAtKind // enum Stage::ComputeAtKind
py::enum_<Stage::ComputeAtKind> compute_at_kind(stage, "ComputeAtKind"); py::enum_<Stage::ComputeAtKind> compute_at_kind(stage, "ComputeAtKind");
compute_at_kind.value("kComputeAtUnk", Stage::ComputeAtKind::kComputeAtAuto) compute_at_kind.value("kComputeAtUnk", Stage::ComputeAtKind::kComputeAtAuto)
......
if(CINN_ONLY) if(WITH_CINN)
file(GLOB_RECURSE CINN_PY_FILES ${PROJECT_SOURCE_DIR}/python/cinn/*.py) file(GLOB_RECURSE CINN_PY_FILES ${PROJECT_SOURCE_DIR}/python/cinn/*.py)
set(CINN_PYTHON_DIR ${PROJECT_SOURCE_DIR}/python/cinn) set(CINN_PYTHON_DIR ${PROJECT_SOURCE_DIR}/python/cinn)
set(CINN_CORE_API ${CMAKE_BINARY_DIR}/python/cinn/core_api.so) set(CINN_CORE_API ${CMAKE_BINARY_DIR}/python/cinn/core_api.so)
...@@ -24,8 +24,8 @@ if(CINN_ONLY) ...@@ -24,8 +24,8 @@ if(CINN_ONLY)
# then core_api.so. # then core_api.so.
add_custom_command( add_custom_command(
OUTPUT ${CINN_CORE_API} POST_BUILD OUTPUT ${CINN_CORE_API} POST_BUILD
COMMAND cp -rf --remove-destination ${CINN_PYTHON_DIR} COMMAND ${CMAKE_COMMAND} -E copy_directory ${CINN_PYTHON_DIR}
${CMAKE_BINARY_DIR}/python/cinn ${PADDLE_BINARY_DIR}/python/cinn
COMMAND cp --remove-destination COMMAND cp --remove-destination
${CMAKE_BINARY_DIR}/paddle/cinn/pybind/core_api.so ${CINN_CORE_API} ${CMAKE_BINARY_DIR}/paddle/cinn/pybind/core_api.so ${CINN_CORE_API}
COMMAND cd ${CMAKE_CURRENT_BINARY_DIR} && ${PYTHON_EXECUTABLE} setup_cinn.py COMMAND cd ${CMAKE_CURRENT_BINARY_DIR} && ${PYTHON_EXECUTABLE} setup_cinn.py
...@@ -35,7 +35,9 @@ if(CINN_ONLY) ...@@ -35,7 +35,9 @@ if(CINN_ONLY)
add_custom_target(COPY_CINN_CORE_API ALL DEPENDS ${CINN_CORE_API} add_custom_target(COPY_CINN_CORE_API ALL DEPENDS ${CINN_CORE_API}
${CINN_PY_FILES}) ${CINN_PY_FILES})
return() if(CINN_ONLY)
return()
endif()
endif() endif()
file(GLOB UTILS_PY_FILES . ./paddle/legacy/utils/*.py) file(GLOB UTILS_PY_FILES . ./paddle/legacy/utils/*.py)
......
...@@ -127,6 +127,7 @@ else: ...@@ -127,6 +127,7 @@ else:
yield yield
libs_path = '${CMAKE_BINARY_DIR}/python/cinn/libs' libs_path = '${CMAKE_BINARY_DIR}/python/cinn/libs'
os.makedirs(libs_path, exist_ok=True)
cinnlibs = [] cinnlibs = []
package_data = {'cinn': ['core_api.so'], 'cinn.libs': []} package_data = {'cinn': ['core_api.so'], 'cinn.libs': []}
......
...@@ -93,11 +93,12 @@ endfunction() ...@@ -93,11 +93,12 @@ endfunction()
if(WITH_TESTING) if(WITH_TESTING)
if(WITH_CINN) if(WITH_CINN)
add_subdirectory(cpp/cinn) add_subdirectory(cpp/cinn)
add_subdirectory(cinn)
endif() endif()
if(CINN_ONLY) if(CINN_ONLY)
add_subdirectory(cinn)
return() return()
endif() endif()
add_subdirectory(amp) add_subdirectory(amp)
add_subdirectory(asp) add_subdirectory(asp)
add_subdirectory(autograd) add_subdirectory(autograd)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册