提交 cfe9f4c2 编写于 作者: M Megvii Engine Team 提交者: 黄信达

Revert "fix(imperative): fix the split dump problem"

This reverts commit 8380af7b464f169b58b1c52791bac0fc983b4f40.

GitOrigin-RevId: 062ed1ebb112b55ad52d6d5479c01fd26085dc5c
上级 a6872cf1
......@@ -400,6 +400,7 @@ py::object get_res_by_refhdl(
ref = py::reinterpret_borrow<py::object>(ref_hdl);
}
if (PyObject_TypeCheck(ref.ptr(), py_varnode_type)) {
auto temp = dtype.cast<mgb::DType>();
ComputingGraph* graph = getattr(ref, "graph").cast<ComputingGraph*>();
cg::VarNode* node = getattr(ref, "var").cast<cg::VarNode*>();
CompNode cn;
......@@ -1472,23 +1473,8 @@ py::object _split_cpp(
std::to_string(axis) + " cannot be split into " +
std::to_string(n_sections) + " sections");
}
if (enable_fastpath(inp_hdl)) {
op = Split::make(axis, n_sections);
p.resize(2);
} else {
size_t n_total_ = n_total.cast<int>();
for (size_t i = 0; i < n_sections; ++i) {
auto section_size = (n_total_ + n_sections - i - 1) / n_sections;
partitions.append(_Const(
py::int_(section_size), py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device")));
}
op = Split::make(axis, 0);
p.resize(partitions.size() + 2);
for (size_t i = 0; i < partitions.size(); ++i) {
p[i + 2] = partitions[i].ptr();
}
}
op = Split::make(axis, n_sections);
p.resize(2);
}
py::object Op = py::cast(op);
p[0] = Op.ptr();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册