未验证 提交 2fc429f1 编写于 作者: Z zyfncg 提交者: GitHub

[CINN] Remove some pybind interface in cinn to fix compile problem (#55043)

* remove some pybind interface in cinn to fix compile problem

* modify cmake

* fix cmake

* add log for build cinn whl

* fix ninja for cinn

* fix conflict
上级 31edad21
......@@ -24,7 +24,7 @@ set(SOURCE_INCLUDE_DIR ${SOURCE_DIR}/include)
include_directories(${PYBIND_INCLUDE_DIR})
set(PYBIND_PATCH_COMMAND "")
if(NOT WIN32 AND NOT CINN_ONLY)
if(NOT WIN32)
file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/pybind/cast.h.patch
native_dst)
# Note: [Why calling some `git` commands before `patch`?]
......
......@@ -12,9 +12,7 @@ add_subdirectory(backends)
add_subdirectory(lang)
add_subdirectory(optim)
add_subdirectory(hlir)
if(CINN_ONLY)
add_subdirectory(pybind)
endif()
add_subdirectory(pybind)
add_subdirectory(frontend)
# Download a model
......
......@@ -179,23 +179,8 @@ void DefineUnaryOpNode(py::module *m, absl::string_view node_name) {
py::return_value_policy::reference);
}
class ObjectWrapper : public Object {
public:
using Object::Object;
const char *type_info() const override {
PYBIND11_OVERLOAD_PURE(const char *, Object, type_info);
}
};
class IrNodeWrapper : ir::IrNode {
using ir::IrNode::IrNode;
};
class _Operation_Wrapper : ir::_Operation_ {
public:
const char *func_type() const override {
PYBIND11_OVERLOAD_PURE(const char *, ir::_Operation_, func_type);
}
};
} // namespace cinn::pybind
......@@ -30,7 +30,6 @@ namespace cinn::pybind {
using common::bfloat16;
using common::CINNValue;
using common::float16;
using common::Object;
using common::Target;
using common::Type;
using utils::GetStreamCnt;
......@@ -39,7 +38,6 @@ using utils::StringFormat;
namespace {
void BindTarget(py::module *);
void BindType(py::module *);
void BindObject(py::module *);
void BindShared(py::module *);
void BindCinnValue(py::module *);
......@@ -208,11 +206,6 @@ void BindType(py::module *m) {
});
}
void BindObject(py::module *m) {
py::class_<Object, ObjectWrapper> object(*m, "Object");
object.def("type_info", &Object::type_info);
}
void BindShared(py::module *m) {
py::class_<common::RefCount> ref_count(*m, "RefCount");
ref_count.def(py::init<>())
......@@ -367,7 +360,6 @@ void BindCinnValue(py::module *m) {
void BindCommon(py::module *m) {
BindTarget(m);
BindType(m);
BindObject(m);
BindShared(m);
BindCinnValue(m);
}
......
......@@ -108,7 +108,7 @@ void BindNode(py::module *m) {
#undef DECLARE_IR_NODE_TY
// class IrNode
py::class_<ir::IrNode, IrNodeWrapper /*, ObjectWrapper*/> ir_node(
py::class_<ir::IrNode, IrNodeWrapper> ir_node(
*m, "IrNode", py::module_local());
ir_node.def(py::init<>())
.def(py::init<ir::Type>())
......@@ -539,14 +539,13 @@ void BindIrIr(py::module *m) {
}
void BindOperation(py::module *m) {
py::class_<ir::PlaceholderOp /*, _Operation_Wrapper*/> placeholder_op(
*m, "PlaceholderOp");
py::class_<ir::PlaceholderOp> placeholder_op(*m, "PlaceholderOp");
placeholder_op.def_readwrite("shape", &ir::PlaceholderOp::shape)
.def_readwrite("dtype", &ir::PlaceholderOp::dtype)
.def_static("make", &ir::PlaceholderOp::Make)
.def("func_type", &ir::PlaceholderOp::func_type);
py::class_<ir::CallOp /*, _Operation_Wrapper*/> call_op(*m, "CallOp");
py::class_<ir::CallOp> call_op(*m, "CallOp");
call_op.def("target", &ir::CallOp::target)
.def_readwrite("call_expr", &ir::CallOp::call_expr)
.def("read_args_mutable", py::overload_cast<>(&ir::CallOp::read_args))
......@@ -564,8 +563,7 @@ void BindOperation(py::module *m) {
.def_static("make", &ir::CallOp::Make)
.def("func_type", &ir::CallOp::func_type);
py::class_<ir::ComputeOp /*, _Operation_Wrapper*/> compute_op(*m,
"ComputeOp");
py::class_<ir::ComputeOp> compute_op(*m, "ComputeOp");
compute_op.def_readwrite("reduce_axis", &ir::ComputeOp::reduce_axis)
.def_readwrite("shape", &ir::ComputeOp::shape)
.def_readwrite("body", &ir::ComputeOp::body)
......
......@@ -69,7 +69,7 @@ void BindStageMap(py::module *m) {
}
void BindStage(py::module *m) {
py::class_<Stage, common::Object> stage(*m, "Stage");
py::class_<Stage> stage(*m, "Stage");
// enum Stage::ComputeAtKind
py::enum_<Stage::ComputeAtKind> compute_at_kind(stage, "ComputeAtKind");
compute_at_kind.value("kComputeAtUnk", Stage::ComputeAtKind::kComputeAtAuto)
......
if(CINN_ONLY)
if(WITH_CINN)
file(GLOB_RECURSE CINN_PY_FILES ${PROJECT_SOURCE_DIR}/python/cinn/*.py)
set(CINN_PYTHON_DIR ${PROJECT_SOURCE_DIR}/python/cinn)
set(CINN_CORE_API ${CMAKE_BINARY_DIR}/python/cinn/core_api.so)
......@@ -24,8 +24,8 @@ if(CINN_ONLY)
# then core_api.so.
add_custom_command(
OUTPUT ${CINN_CORE_API} POST_BUILD
COMMAND cp -rf --remove-destination ${CINN_PYTHON_DIR}
${CMAKE_BINARY_DIR}/python/cinn
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CINN_PYTHON_DIR}
${PADDLE_BINARY_DIR}/python/cinn
COMMAND cp --remove-destination
${CMAKE_BINARY_DIR}/paddle/cinn/pybind/core_api.so ${CINN_CORE_API}
COMMAND cd ${CMAKE_CURRENT_BINARY_DIR} && ${PYTHON_EXECUTABLE} setup_cinn.py
......@@ -35,7 +35,9 @@ if(CINN_ONLY)
add_custom_target(COPY_CINN_CORE_API ALL DEPENDS ${CINN_CORE_API}
${CINN_PY_FILES})
return()
if(CINN_ONLY)
return()
endif()
endif()
file(GLOB UTILS_PY_FILES . ./paddle/legacy/utils/*.py)
......
......@@ -127,6 +127,7 @@ else:
yield
libs_path = '${CMAKE_BINARY_DIR}/python/cinn/libs'
os.makedirs(libs_path, exist_ok=True)
cinnlibs = []
package_data = {'cinn': ['core_api.so'], 'cinn.libs': []}
......
......@@ -93,11 +93,12 @@ endfunction()
if(WITH_TESTING)
if(WITH_CINN)
add_subdirectory(cpp/cinn)
add_subdirectory(cinn)
endif()
if(CINN_ONLY)
add_subdirectory(cinn)
return()
endif()
add_subdirectory(amp)
add_subdirectory(asp)
add_subdirectory(autograd)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册