提交 380cb6e4 编写于 作者: M Megvii Engine Team

feat(mge/jit): add support output symbol var name settings for dump

GitOrigin-RevId: 258c03ee34ac7cd120372e57152c95d613179bab
上级 e1e56988
......@@ -235,14 +235,35 @@ class Tensor:
return self.__val.dtype
return self._symvar.dtype
def set_dtype(self, dtype: str = None):
@dtype.setter
def dtype(self, dtype: str = None):
r"""Set the data type of the tensor.
"""
if self.__val is not None:
self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
elif self.__sym_override is not None:
self.__sym_override = self.__sym_override.astype(dtype)
elif self.__sym is not None:
self.__sym = self.__sym.astype(dtype)
@property
def name(self):
r"""Get the tensor name, does not support Parameter and Buffer.
"""
return self._symvar.name
@name.setter
def name(self, name: str = None):
r"""Set the tensor name, does not support Parameter and Buffer.
"""
if self.__val is not None:
raise ValueError("name setting is not available for Parameter or Buffer.")
if self.__sym_override is not None:
self.__sym_override = self.__sym_override.rename(name)
if self.__sym is not None:
assert not self.__val
self.__sym = self.__sym.rename(name)
@property
def _comp_node(self):
if self.__val is not None:
......
......@@ -436,6 +436,7 @@ class trace:
arg_names=None,
append=False,
optimize_for_inference=False,
output_names=None,
**kwargs
):
"""
......@@ -446,6 +447,8 @@ class trace:
:param append: whether output is appended to ``fpath``.
:param optimize_for_inference: whether to enable optimize_for_inference
pass before dump.
:param output_names: names of the output tensors in the traced function,
will use the default name if does not specify.
:param enable_io16xc32: whether to use float16 for I/O between oprs and use
float32 as internal computation precision. Note the output var would be
......@@ -488,6 +491,17 @@ class trace:
len(self._args), len(arg_names)
)
)
if isinstance(output_names, str):
output_names = [output_names]
if output_names is None:
output_names = [var.name for var in self._sym_outputs]
elif len(output_names) != len(self._sym_outputs):
raise ValueError(
"len(output_names) should be {}, got {}".format(
len(self._sym_outputs), len(output_names)
)
)
optimize_for_inference_args_map = {
"enable_io16xc32": "f16_io_f32_comp",
"enable_ioc16": "f16_io_comp",
......@@ -541,6 +555,8 @@ class trace:
sym_outputs = mgb.optimize_for_inference(
sym_outputs, **optimize_for_inference_kwargs
)
for var, name in zip(sym_outputs, output_names):
var.rename(name)
mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append)
def get_profile(self):
......
......@@ -464,7 +464,7 @@ class Module(metaclass=ABCMeta):
# For quantized dtype, the initialized dtype
# scale/zero_points maybe invalid, use pretrained dtype instead.
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
var.set_dtype(to_be_load.dtype)
var.dtype = to_be_load.dtype
var.set_value(to_be_load)
loaded.append(k)
......
......@@ -46,29 +46,46 @@ def test_tensor_set_dtype():
)
t = mge.Parameter(np.ones((3, 4), dtype="float32"))
t.set_dtype(mgb.dtype.qint8(0.1))
t.dtype = mgb.dtype.qint8(0.1)
check_dtype_value(t, 0.1, 10)
t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
t.set_dtype(mgb.dtype.qint8(0.3))
t.dtype = mgb.dtype.qint8(0.3)
check_dtype_value(t, 0.3, 3)
t = mge.Buffer(np.ones((3, 4), dtype="float32"))
t.set_dtype(mgb.dtype.qint8(0.1))
t.dtype = mgb.dtype.qint8(0.1)
check_dtype_value(t, 0.1, 10)
t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
t.set_dtype(mgb.dtype.qint8(0.3))
t.dtype = mgb.dtype.qint8(0.3)
check_dtype_value(t, 0.3, 3)
t = mge.Buffer(np.ones((3, 4), dtype="float32"))
s = t + 1
s.set_dtype(mgb.dtype.qint8(0.2))
s.dtype = mgb.dtype.qint8(0.2)
check_dtype_value(s, 0.2, 10)
t.set_dtype(mgb.dtype.qint8(0.3))
t.dtype = mgb.dtype.qint8(0.3)
s = t + 1
s.set_dtype(mgb.dtype.qint8(0.1))
s.dtype = mgb.dtype.qint8(0.1)
check_dtype_value(s, 0.1, 18)
s.set_dtype("float32")
s.dtype = "float32"
check_dtype_value(s, 0, 1.8)
def test_tensor_name():
p = mge.Parameter(np.ones((3, 4), dtype="float32"))
assert "shared" in p.name
with pytest.raises(ValueError):
p.name = "Parameter0"
b = mge.Buffer(np.ones((3, 4), dtype="float32"))
assert "shared" in b.name
with pytest.raises(ValueError):
b.name = "Buffer0"
s = b + 1
assert "ADD" in s.name
s.name = "WeightAdd1"
assert s.name == "WeightAdd1"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册