diff --git a/python_module/megengine/core/tensor.py b/python_module/megengine/core/tensor.py
index b8d6e1c95aa0606a4446c359d03c2fd80d6d005a..a5ef8a0ad51f776e86ad054687e86efa1bb7e0e8 100644
--- a/python_module/megengine/core/tensor.py
+++ b/python_module/megengine/core/tensor.py
@@ -235,14 +235,35 @@ class Tensor:
             return self.__val.dtype
         return self._symvar.dtype
 
-    def set_dtype(self, dtype: str = None):
+    @dtype.setter
+    def dtype(self, dtype: str = None):
         r"""Set the data type of the tensor.
         """
         if self.__val is not None:
             self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
+        elif self.__sym_override is not None:
+            self.__sym_override = self.__sym_override.astype(dtype)
         elif self.__sym is not None:
             self.__sym = self.__sym.astype(dtype)
 
+    @property
+    def name(self):
+        r"""Get the tensor name; not supported for Parameter and Buffer.
+        """
+        return self._symvar.name
+
+    @name.setter
+    def name(self, name: str = None):
+        r"""Set the tensor name; not supported for Parameter and Buffer.
+        """
+        if self.__val is not None:
+            raise ValueError("name setting is not available for Parameter or Buffer.")
+        if self.__sym_override is not None:
+            self.__sym_override = self.__sym_override.rename(name)
+        if self.__sym is not None:
+            assert not self.__val
+            self.__sym = self.__sym.rename(name)
+
     @property
     def _comp_node(self):
         if self.__val is not None:
diff --git a/python_module/megengine/jit/__init__.py b/python_module/megengine/jit/__init__.py
index 6add694dc6459898061786120bc856446359d1ad..b3aabdbff9a2b7516a85698c97971ee9c62da77a 100644
--- a/python_module/megengine/jit/__init__.py
+++ b/python_module/megengine/jit/__init__.py
@@ -436,6 +436,7 @@ class trace:
         arg_names=None,
         append=False,
         optimize_for_inference=False,
+        output_names=None,
         **kwargs
     ):
         """
@@ -446,6 +447,8 @@ class trace:
         :param append: whether output is appended to ``fpath``.
         :param optimize_for_inference: whether to enable optimize_for_inference
             pass before dump.
+        :param output_names: names of the output tensors in the traced function;
+            the default names are used if not specified.
         :param enable_io16xc32: whether to use float16 for I/O between oprs and use
             float32 as internal computation precision. Note the output var would be
             changed to float16.
@@ -488,6 +491,17 @@ class trace:
                         len(self._args), len(arg_names)
                     )
                 )
+        if isinstance(output_names, str):
+            output_names = [output_names]
+        if output_names is None:
+            output_names = [var.name for var in self._sym_outputs]
+        elif len(output_names) != len(self._sym_outputs):
+            raise ValueError(
+                "len(output_names) should be {}, got {}".format(
+                    len(self._sym_outputs), len(output_names)
+                )
+            )
+
         optimize_for_inference_args_map = {
             "enable_io16xc32": "f16_io_f32_comp",
             "enable_ioc16": "f16_io_comp",
@@ -541,6 +555,8 @@ class trace:
             sym_outputs = mgb.optimize_for_inference(
                 sym_outputs, **optimize_for_inference_kwargs
             )
+        for var, name in zip(sym_outputs, output_names):
+            var.rename(name)
         mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append)
 
     def get_profile(self):
diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py
index af63cdcffaaf1629fee56c05de00ecfc902c281d..66aac1e9455b4d0dadf410cb50ebd418ed6c9745 100644
--- a/python_module/megengine/module/module.py
+++ b/python_module/megengine/module/module.py
@@ -464,7 +464,7 @@ class Module(metaclass=ABCMeta):
            # For quantized dtype, the initialized dtype
            # scale/zero_points maybe invalid, use pretrained dtype instead.
            if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
-                var.set_dtype(to_be_load.dtype)
+                var.dtype = to_be_load.dtype
            var.set_value(to_be_load)
            loaded.append(k)
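A minimal usage sketch of the property-based API introduced above (not part of the patch; it assumes MegEngine 0.x, where `mgb` is the `megengine._internal` binding that the tests below also use):

```python
import numpy as np
import megengine as mge
import megengine._internal as mgb  # assumption: the MegEngine 0.x internal binding

# dtype is now a read/write property rather than a set_dtype() method.
t = mge.Buffer(np.ones((3, 4), dtype="float32"))
t.dtype = mgb.dtype.qint8(0.1)   # previously: t.set_dtype(mgb.dtype.qint8(0.1))

# name is readable on any tensor but writable only on symbolic ones.
s = t + 1                        # symbolic result of an elementwise op
s.name = "WeightAdd1"            # renames the underlying symbol var
# t.name = "Buffer0"             # would raise ValueError: name setting is not
#                                # available for Parameter or Buffer
```

Making `dtype` a setter keeps reads and writes symmetric (`t.dtype` / `t.dtype = ...`), which is why `Module.load_state_dict` above switches from `var.set_dtype(...)` to a plain assignment.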
diff --git a/python_module/test/unit/core/test_tensor.py b/python_module/test/unit/core/test_tensor.py
index 5f87707605654f68061b2f7791120cd3cc075fa8..3ac0a4c89648088f719443aec73f5410548d693c 100644
--- a/python_module/test/unit/core/test_tensor.py
+++ b/python_module/test/unit/core/test_tensor.py
@@ -46,29 +46,46 @@ def test_tensor_set_dtype():
     )
 
     t = mge.Parameter(np.ones((3, 4), dtype="float32"))
-    t.set_dtype(mgb.dtype.qint8(0.1))
+    t.dtype = mgb.dtype.qint8(0.1)
     check_dtype_value(t, 0.1, 10)
 
     t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
-    t.set_dtype(mgb.dtype.qint8(0.3))
+    t.dtype = mgb.dtype.qint8(0.3)
     check_dtype_value(t, 0.3, 3)
 
     t = mge.Buffer(np.ones((3, 4), dtype="float32"))
-    t.set_dtype(mgb.dtype.qint8(0.1))
+    t.dtype = mgb.dtype.qint8(0.1)
     check_dtype_value(t, 0.1, 10)
 
     t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1)))
-    t.set_dtype(mgb.dtype.qint8(0.3))
+    t.dtype = mgb.dtype.qint8(0.3)
     check_dtype_value(t, 0.3, 3)
 
     t = mge.Buffer(np.ones((3, 4), dtype="float32"))
     s = t + 1
-    s.set_dtype(mgb.dtype.qint8(0.2))
+    s.dtype = mgb.dtype.qint8(0.2)
     check_dtype_value(s, 0.2, 10)
-    t.set_dtype(mgb.dtype.qint8(0.3))
+    t.dtype = mgb.dtype.qint8(0.3)
     s = t + 1
-    s.set_dtype(mgb.dtype.qint8(0.1))
+    s.dtype = mgb.dtype.qint8(0.1)
     check_dtype_value(s, 0.1, 18)
-    s.set_dtype("float32")
+    s.dtype = "float32"
     check_dtype_value(s, 0, 1.8)
+
+
+def test_tensor_name():
+    p = mge.Parameter(np.ones((3, 4), dtype="float32"))
+    assert "shared" in p.name
+    with pytest.raises(ValueError):
+        p.name = "Parameter0"
+
+    b = mge.Buffer(np.ones((3, 4), dtype="float32"))
+    assert "shared" in b.name
+    with pytest.raises(ValueError):
+        b.name = "Buffer0"
+
+    s = b + 1
+    assert "ADD" in s.name
+    s.name = "WeightAdd1"
+    assert s.name == "WeightAdd1"
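A hypothetical end-to-end sketch of `dump` with the new `output_names` argument; the tracing setup (symbolic mode, one recording call) follows the usual MegEngine 0.x workflow, and the function and file names here are invented for illustration:

```python
import numpy as np
import megengine as mge
from megengine.jit import trace

@trace(symbolic=True)
def double(x):
    return x * 2

# Run once so the traced outputs (self._sym_outputs) exist before dumping.
double(mge.tensor(np.ones((2, 2), dtype="float32")))

# A plain str is wrapped into a one-element list; a list whose length does
# not match the number of outputs makes dump raise ValueError.
double.dump("double.mge", output_names="doubled")
```

Note that the rename happens after the optional `optimize_for_inference` pass, so the serialized graph keeps the requested output names regardless of which optimization flags are set.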