提交 54633b92 编写于 作者: M Megvii Engine Team 提交者: 黄信达

fix(imperative): fix the split dump problem

GitOrigin-RevId: 0a0265e59819a89b12853229eff9d14d1e55ace6
上级 6da3de19
......@@ -400,7 +400,6 @@ py::object get_res_by_refhdl(
ref = py::reinterpret_borrow<py::object>(ref_hdl);
}
if (PyObject_TypeCheck(ref.ptr(), py_varnode_type)) {
auto temp = dtype.cast<mgb::DType>();
ComputingGraph* graph = getattr(ref, "graph").cast<ComputingGraph*>();
cg::VarNode* node = getattr(ref, "var").cast<cg::VarNode*>();
CompNode cn;
......@@ -1473,8 +1472,23 @@ py::object _split_cpp(
std::to_string(axis) + " cannot be split into " +
std::to_string(n_sections) + " sections");
}
op = Split::make(axis, n_sections);
p.resize(2);
if (enable_fastpath(inp_hdl)) {
op = Split::make(axis, n_sections);
p.resize(2);
} else {
size_t n_total_ = n_total.cast<int>();
for (size_t idx = 0; idx < n_sections; ++idx) {
auto section_size = (n_total_ + n_sections - idx - 1) / n_sections;
partitions.append(_Const(
py::int_(section_size), py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device")));
}
op = Split::make(axis, 0);
p.resize(partitions.size() + 2);
for (size_t idx = 0; idx < partitions.size(); ++idx) {
p[idx + 2] = partitions[idx].ptr();
}
}
}
py::object Op = py::cast(op);
p[0] = Op.ptr();
......
......@@ -308,6 +308,28 @@ def test_dump_with_testcase():
f.dump(file, input_data=["#rand(0, 255, 1)"])
def test_split_dump():
class SimpleNet(Module):
def __init__(self, num_segments: int = 3):
super().__init__()
self.num_segments = num_segments
def forward(self, x):
x = F.split(x, self.num_segments, axis=1)
return x
model = SimpleNet()
model.eval()
data = tensor(np.random.random((1, 12, 224, 224)))
@trace(symbolic=True, capture_as_const=True)
def fun(data, *, net):
return net(data)
x = fun(data, net=model)
fun.dump(io.BytesIO(), arg_names=["data"])
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_profiler(trace_mode):
@trace(symbolic=trace_mode, profiling=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册