提交 d026ba4f 编写于 作者: M Megvii Engine Team

fix(imperative): fix the split dump problem

GitOrigin-RevId: 8380af7b464f169b58b1c52791bac0fc983b4f40
上级 803bb79f
......@@ -400,7 +400,6 @@ py::object get_res_by_refhdl(
ref = py::reinterpret_borrow<py::object>(ref_hdl);
}
if (PyObject_TypeCheck(ref.ptr(), py_varnode_type)) {
auto temp = dtype.cast<mgb::DType>();
ComputingGraph* graph = getattr(ref, "graph").cast<ComputingGraph*>();
cg::VarNode* node = getattr(ref, "var").cast<cg::VarNode*>();
CompNode cn;
......@@ -1473,8 +1472,23 @@ py::object _split_cpp(
std::to_string(axis) + " cannot be split into " +
std::to_string(n_sections) + " sections");
}
op = Split::make(axis, n_sections);
p.resize(2);
if (enable_fastpath(inp_hdl)) {
op = Split::make(axis, n_sections);
p.resize(2);
} else {
size_t n_total_ = n_total.cast<int>();
for (size_t i = 0; i < n_sections; ++i) {
auto section_size = (n_total_ + n_sections - i - 1) / n_sections;
partitions.append(_Const(
py::int_(section_size), py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device")));
}
op = Split::make(axis, 0);
p.resize(partitions.size() + 2);
for (size_t i = 0; i < partitions.size(); ++i) {
p[i + 2] = partitions[i].ptr();
}
}
}
py::object Op = py::cast(op);
p[0] = Op.ptr();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册