Commit 87f00232 authored by Megvii Engine Team

fix(mge/gm): fix missing dtype checking while attaching tensors

GitOrigin-RevId: f0aaea99b93472b893eeb3ba35c6293c5f15b122
Parent 3726f5cc
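Before this change, GradManager.attach() accepted tensors of any dtype and only failed later, if at all; the diff below makes attach() reject non-floating-point tensors up front. A minimal sketch of the resulting behavior (assuming a MegEngine build that includes this commit):

import numpy as np
import megengine as mge
from megengine.autodiff import GradManager

gm = GradManager()

w = mge.tensor([1.0], dtype=np.float32)
gm.attach([w])  # ok: float32 is a differentiable dtype

x = mge.tensor([1], dtype=np.int32)
try:
    gm.attach([x])  # now fails fast with an AssertionError
except AssertionError as exc:
    print(exc)  # names the offending dtype and shape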
@@ -6,11 +6,11 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import weakref
-from collections import OrderedDict
 from typing import Callable, Iterable, List, Union

 from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
 from ..core.autodiff.grad import Grad
+from ..core.tensor.dtype import is_differentible_dtype
 from ..logger import get_logger
 from ..tensor import Tensor
 from ..utils.future import Future
@@ -208,6 +208,10 @@ class GradManager:
         for x in tensors:
             assert isinstance(x, Tensor), "Object to be attached should be Tensor"
+            assert is_differentible_dtype(x.dtype), (
+                "Only tensors of floating point dtype can be attached to get gradients, "
+                "got tensor dtype: {} and shape: {}".format(x.dtype, x.shape)
+            )
             spec = self._attach_specs.get(id(x))
             new_attach = spec is None
             if spec is None:
...
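Code that previously attached integer tensors will now fail at attach() time. A hedged migration sketch (my own, not part of the commit): cast to a floating point dtype before attaching:

import numpy as np
import megengine as mge
from megengine.autodiff import GradManager

gm = GradManager()
x = mge.tensor([1, 2, 3], dtype=np.int32)

x_f = x.astype(np.float32)  # cast first; integer tensors can no longer be attached
gm.attach([x_f])
with gm:  # enter record mode
    y = (x_f * 2.0).sum()
    gm.backward(y)
print(x_f.grad)  # gradient of y with respect to x_f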
@@ -38,6 +38,10 @@ def is_bfloat16(dtype):
     return dtype is bfloat16


+def is_differentible_dtype(dtype):
+    return dtype == np.float32 or dtype == np.float16 or is_bfloat16(dtype)
+
+
 # quantization dtype related
 # use namedtuple to make class immutable, comparable and easy to print
@@ -114,7 +118,7 @@ def create_quantized_dtype(
     dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
 ):
     r"""Get quantized dtype with metadata attribute according to _metadata_dict.

     Note that unsigned dtype must have ``zero_point`` and signed dtype must
     not have ``zero_point``, to be consistent with tensor generated by calling
     compiled function from `CompGraph.compile(inputs, outspec)`.
...
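The new helper deliberately whitelists only float32, float16, and bfloat16 (the module-level dtype checked by is_bfloat16 above). A quick sketch of its behavior, assuming bfloat16 is importable from the same module:

import numpy as np
from megengine.core.tensor.dtype import bfloat16, is_differentible_dtype

assert is_differentible_dtype(np.float32)
assert is_differentible_dtype(np.float16)
assert is_differentible_dtype(bfloat16)
assert not is_differentible_dtype(np.int32)    # integers are rejected
assert not is_differentible_dtype(np.float64)  # float64 is not whitelisted either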
@@ -13,6 +13,7 @@ import numpy as np
 import pytest

 import megengine as mge
+import megengine.core.tensor.dtype as dtype
 import megengine.distributed as dist
 import megengine.functional as F
 import megengine.module as M
@@ -469,3 +470,18 @@ def test_2nd_grad_with_custom_gradient():
     np.testing.assert_almost_equal(
         x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
     )
+
+
+@pytest.mark.parametrize("invalid_dtype", [np.uint8, np.int8, np.int32])
+def test_attach_invalid_tensor_dtype(invalid_dtype):
+    gm = GradManager()
+    x = mge.tensor([1], dtype=invalid_dtype)
+    with pytest.raises(AssertionError):
+        gm.attach([x])
+
+
+@pytest.mark.parametrize("differentible_dtype", [np.float32, np.float16])
+def test_attach_differentible_tensor_dtype(differentible_dtype):
+    gm = GradManager()
+    x = mge.tensor([1], dtype=differentible_dtype)
+    gm.attach([x])
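The parametrized tests above only check that an AssertionError is raised. If the message itself were worth pinning down, pytest's match argument could do it; a hedged sketch in the same style (the test name is mine, not from the commit):

@pytest.mark.parametrize("invalid_dtype", [np.uint8, np.int8, np.int32])
def test_attach_invalid_dtype_message(invalid_dtype):
    gm = GradManager()
    x = mge.tensor([1], dtype=invalid_dtype)
    with pytest.raises(AssertionError, match="floating point"):
        gm.attach([x])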