提交 e474994f 编写于 作者: M Megvii Engine Team

feat(imperative/jit): catch input tensors name when tracing

GitOrigin-RevId: 9c692548663654265f9f9e2753f8637d444cb78d
上级 aed681d3
......@@ -772,7 +772,8 @@ class trace:
len(self._output_bindings)
)
)
if arg_names is None:
without_arg_names = arg_names is None
if without_arg_names:
arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
if arg_names and not isinstance(arg_names, collections.abc.Sequence):
arg_names = (arg_names,)
......@@ -802,7 +803,7 @@ class trace:
dtype=info.dtype,
device=dumped_device(info),
shape=info.shape or (1,),
name=arg_names[i] if arg_names else None,
name=info.name if without_arg_names and info.name else arg_names[i],
)
for k, h in self._kwarg_bindings.items():
info = self._tinfo[h]
......@@ -889,6 +890,7 @@ class trace:
return
h, info = self._new_handle()
info.external = False
info.name = x.c_name
info.device = x.device
info.dtype = x.dtype
info.shape = x.numpy().shape
......
......@@ -203,14 +203,31 @@ def test_with_same_operators(symbolic):
assert ops[-2].name == "simple.RELU[0]"
def test_not_keep_opr_name():
@pytest.mark.parametrize("symbolic", [False, True])
def test_not_keep_opr_name(symbolic):
def f(x):
return 2 * x
op = _dump_and_load(f, True, False)[-1]
op = _dump_and_load(f, symbolic, False)[-1]
assert op.name == "MUL(x,const<2>[2])[4]"
@pytest.mark.parametrize("tensor_name, var_name", [("data", "data"), (None, "arg_0")])
def test_catch_input_name(tensor_name, var_name):
def f(x):
return 2 * x
func = trace(f, symbolic=True, capture_as_const=True)
x = Tensor(np.ones(shape=(2, 3)), name=tensor_name)
func(x).numpy()
file = io.BytesIO()
func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2)
file.seek(0)
*_, outputs = G.load_graph(file)
op = cgtools.get_oprs_seq(outputs)[-1]
assert op.inputs[0].name == var_name
@pytest.mark.parametrize("symbolic", [False, True])
def test_quantized_module_auto_naming(symbolic):
class Simple(M.Module):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册