提交 9e6544bf 编写于 作者: M Megvii Engine Team

fix(trace): fix name duplication and fix error message for invalid input

GitOrigin-RevId: 7fe1605c2639ba2c67488e0bffd6d0e2fab73e6a
上级 55cdda79
...@@ -216,6 +216,9 @@ class trace: ...@@ -216,6 +216,9 @@ class trace:
def _process_inputs(self, *args, **kwargs): def _process_inputs(self, *args, **kwargs):
for i, arg in enumerate(args): for i, arg in enumerate(args):
assert isinstance(
arg, RawTensor
), "Only support tensor type args when capture_as_const is enabled"
name_tensor("arg_{}".format(i), arg) name_tensor("arg_{}".format(i), arg)
# TODO: mark kwargs in order # TODO: mark kwargs in order
......
...@@ -1229,6 +1229,7 @@ void init_tensor(py::module m) { ...@@ -1229,6 +1229,7 @@ void init_tensor(py::module m) {
m.def("name_tensor", [](std::string name, py::object tensor) { m.def("name_tensor", [](std::string name, py::object tensor) {
auto* tw = TensorWrapper::try_cast(tensor.ptr()); auto* tw = TensorWrapper::try_cast(tensor.ptr());
mgb_assert(tw, "Arg_1 shoud be Tensor!");
auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0]; auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
tw->m_tensor->reset(output); tw->m_tensor->reset(output);
}); });
......
...@@ -748,3 +748,37 @@ def test_trace_jit_config(): ...@@ -748,3 +748,37 @@ def test_trace_jit_config():
for fuse_dimshuffle in [None, False, True]: for fuse_dimshuffle in [None, False, True]:
for fuse_reduce in [None, False, True]: for fuse_reduce in [None, False, True]:
run(fuse_dimshuffle, fuse_reduce) run(fuse_dimshuffle, fuse_reduce)
def test_trace_naming():
@trace(symbolic=True, capture_as_const=True)
def func(x):
return F.max(x, axis=2, keepdims=False) + 1
inp = tensor(np.random.random((1, 3, 3, 3)))
func(inp)
file = io.BytesIO()
func.dump(file, optimize_for_inference=False)
file.seek(0)
import megengine.utils.network as network
net = network.Network.load(file)
names = set()
for var in net.all_vars:
assert var.name not in names
names.add(var.name)
def test_invalid_inp_error():
@trace(capture_as_const=True)
def func(a):
return a * 2
try:
func(1)
except Exception as e:
assert (
str(e) == "Only support tensor type args when capture_as_const is enabled"
)
else:
assert False
...@@ -98,8 +98,6 @@ VarNodeArray TraceResult::dump( ...@@ -98,8 +98,6 @@ VarNodeArray TraceResult::dump(
"do model.eval()?"); "do model.eval()?");
} }
output_nodes = OpDef::apply_on_var_node(*op, input_nodes); output_nodes = OpDef::apply_on_var_node(*op, input_nodes);
name2ops[output_nodes[0]->owner_opr()->name()].push_back(
output_nodes[0]->owner_opr());
} else { } else {
// no opr, just forward VarNode // no opr, just forward VarNode
mgb_assert( mgb_assert(
...@@ -121,6 +119,13 @@ VarNodeArray TraceResult::dump( ...@@ -121,6 +119,13 @@ VarNodeArray TraceResult::dump(
} }
} }
} }
auto on_opr = [&name2ops](cg::OperatorNodeBase* opr) {
name2ops[opr->name()].push_back(opr);
};
cg::DepOprIter dep_iter(on_opr);
for (auto&& [output, name] : outputs) {
dep_iter.add(nodes[output]->owner_opr());
}
for (auto&& [name, ops] : name2ops) { for (auto&& [name, ops] : name2ops) {
if (ops.size() <= 1) { if (ops.size() <= 1) {
continue; continue;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册