From 380cb6e47f27936d4d1ffb1b70f9ca2b45be1195 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 23 Jul 2020 16:44:30 +0800 Subject: [PATCH] feat(mge/jit): add support output symbol var name settings for dump GitOrigin-RevId: 258c03ee34ac7cd120372e57152c95d613179bab --- python_module/megengine/core/tensor.py | 23 +++++++++++++- python_module/megengine/jit/__init__.py | 16 ++++++++++ python_module/megengine/module/module.py | 2 +- python_module/test/unit/core/test_tensor.py | 33 ++++++++++++++++----- 4 files changed, 64 insertions(+), 10 deletions(-) diff --git a/python_module/megengine/core/tensor.py b/python_module/megengine/core/tensor.py index b8d6e1c9..a5ef8a0a 100644 --- a/python_module/megengine/core/tensor.py +++ b/python_module/megengine/core/tensor.py @@ -235,14 +235,35 @@ class Tensor: return self.__val.dtype return self._symvar.dtype - def set_dtype(self, dtype: str = None): + @dtype.setter + def dtype(self, dtype: str = None): r"""Set the data type of the tensor. """ if self.__val is not None: self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy()) + elif self.__sym_override is not None: + self.__sym_override = self.__sym_override.astype(dtype) elif self.__sym is not None: self.__sym = self.__sym.astype(dtype) + @property + def name(self): + r"""Get the tensor name, does not support Parameter and Buffer. + """ + return self._symvar.name + + @name.setter + def name(self, name: str = None): + r"""Set the tensor name, does not support Parameter and Buffer. + """ + if self.__val is not None: + raise ValueError("name setting is not available for Parameter or Buffer.") + if self.__sym_override is not None: + self.__sym_override = self.__sym_override.rename(name) + if self.__sym is not None: + assert not self.__val + self.__sym = self.__sym.rename(name) + @property def _comp_node(self): if self.__val is not None: diff --git a/python_module/megengine/jit/__init__.py b/python_module/megengine/jit/__init__.py index 6add694d..b3aabdbf 100644 --- a/python_module/megengine/jit/__init__.py +++ b/python_module/megengine/jit/__init__.py @@ -436,6 +436,7 @@ class trace: arg_names=None, append=False, optimize_for_inference=False, + output_names=None, **kwargs ): """ @@ -446,6 +447,8 @@ class trace: :param append: whether output is appended to ``fpath``. :param optimize_for_inference: whether to enable optimize_for_inference pass before dump. + :param output_names: names of the output tensors in the traced function, + will use the default name if does not specify. :param enable_io16xc32: whether to use float16 for I/O between oprs and use float32 as internal computation precision. Note the output var would be @@ -488,6 +491,17 @@ class trace: len(self._args), len(arg_names) ) ) + if isinstance(output_names, str): + output_names = [output_names] + if output_names is None: + output_names = [var.name for var in self._sym_outputs] + elif len(output_names) != len(self._sym_outputs): + raise ValueError( + "len(output_names) should be {}, got {}".format( + len(self._sym_outputs), len(output_names) + ) + ) + optimize_for_inference_args_map = { "enable_io16xc32": "f16_io_f32_comp", "enable_ioc16": "f16_io_comp", @@ -541,6 +555,8 @@ class trace: sym_outputs = mgb.optimize_for_inference( sym_outputs, **optimize_for_inference_kwargs ) + for var, name in zip(sym_outputs, output_names): + var.rename(name) mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append) def get_profile(self): diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index af63cdcf..66aac1e9 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -464,7 +464,7 @@ class Module(metaclass=ABCMeta): # For quantized dtype, the initialized dtype # scale/zero_points maybe invalid, use pretrained dtype instead. if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): - var.set_dtype(to_be_load.dtype) + var.dtype = to_be_load.dtype var.set_value(to_be_load) loaded.append(k) diff --git a/python_module/test/unit/core/test_tensor.py b/python_module/test/unit/core/test_tensor.py index 5f877076..3ac0a4c8 100644 --- a/python_module/test/unit/core/test_tensor.py +++ b/python_module/test/unit/core/test_tensor.py @@ -46,29 +46,46 @@ def test_tensor_set_dtype(): ) t = mge.Parameter(np.ones((3, 4), dtype="float32")) - t.set_dtype(mgb.dtype.qint8(0.1)) + t.dtype = mgb.dtype.qint8(0.1) check_dtype_value(t, 0.1, 10) t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1))) - t.set_dtype(mgb.dtype.qint8(0.3)) + t.dtype = mgb.dtype.qint8(0.3) check_dtype_value(t, 0.3, 3) t = mge.Buffer(np.ones((3, 4), dtype="float32")) - t.set_dtype(mgb.dtype.qint8(0.1)) + t.dtype = mgb.dtype.qint8(0.1) check_dtype_value(t, 0.1, 10) t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1))) - t.set_dtype(mgb.dtype.qint8(0.3)) + t.dtype = mgb.dtype.qint8(0.3) check_dtype_value(t, 0.3, 3) t = mge.Buffer(np.ones((3, 4), dtype="float32")) s = t + 1 - s.set_dtype(mgb.dtype.qint8(0.2)) + s.dtype = mgb.dtype.qint8(0.2) check_dtype_value(s, 0.2, 10) - t.set_dtype(mgb.dtype.qint8(0.3)) + t.dtype = mgb.dtype.qint8(0.3) s = t + 1 - s.set_dtype(mgb.dtype.qint8(0.1)) + s.dtype = mgb.dtype.qint8(0.1) check_dtype_value(s, 0.1, 18) - s.set_dtype("float32") + s.dtype = "float32" check_dtype_value(s, 0, 1.8) + + +def test_tensor_name(): + p = mge.Parameter(np.ones((3, 4), dtype="float32")) + assert "shared" in p.name + with pytest.raises(ValueError): + p.name = "Parameter0" + + b = mge.Buffer(np.ones((3, 4), dtype="float32")) + assert "shared" in b.name + with pytest.raises(ValueError): + b.name = "Buffer0" + + s = b + 1 + assert "ADD" in s.name + s.name = "WeightAdd1" + assert s.name == "WeightAdd1" -- GitLab