Commit 87f00232 authored by Megvii Engine Team

fix(mge/gm): fix missing dtype checking while attaching tensors

GitOrigin-RevId: f0aaea99b93472b893eeb3ba35c6293c5f15b122
Parent 3726f5cc
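For context, a minimal sketch (editor-added, not part of the commit) of the behavior this change enforces, using only the public API that appears in the diff and tests below:

# Minimal sketch of the behavior this commit enforces (not part of the diff).
import numpy as np
import megengine as mge
from megengine.autodiff import GradManager

gm = GradManager()

w = mge.tensor([1.0], dtype=np.float32)
gm.attach([w])  # OK: float32 is a differentiable dtype

x = mge.tensor([1], dtype=np.int32)
try:
    gm.attach([x])  # with this fix: raises AssertionError at attach time
except AssertionError as exc:
    print(exc)  # "Only tensors of floating point dtype can be attached ..."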
@@ -6,11 +6,11 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import weakref
from collections import OrderedDict
from typing import Callable, Iterable, List, Union
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core.autodiff.grad import Grad
from ..core.tensor.dtype import is_differentiable_dtype
from ..logger import get_logger
from ..tensor import Tensor
from ..utils.future import Future
@@ -208,6 +208,10 @@ class GradManager:
        for x in tensors:
            assert isinstance(x, Tensor), "Object to be attached should be Tensor"
            assert is_differentiable_dtype(x.dtype), (
                "Only tensors of floating point dtype can be attached to get gradients, "
                "got tensor dtype: {} and shape: {}".format(x.dtype, x.shape)
            )
            spec = self._attach_specs.get(id(x))
            new_attach = spec is None
            if spec is None:
......
@@ -38,6 +38,10 @@ def is_bfloat16(dtype):
    return dtype is bfloat16


def is_differentiable_dtype(dtype):
    return dtype == np.float32 or dtype == np.float16 or is_bfloat16(dtype)


# quantization dtype related
# use namedtuple to make class immutable, comparable and easy to print
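A quick illustration (editor-added, not part of the diff) of the helper above, assuming it is importable from this module:

import numpy as np
from megengine.core.tensor.dtype import is_differentiable_dtype

assert is_differentiable_dtype(np.float32)    # floating point: differentiable
assert not is_differentiable_dtype(np.int32)  # integer: not differentiable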
@@ -114,7 +118,7 @@ def create_quantized_dtype(
    dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
):
    r"""Get quantized dtype with metadata attribute according to _metadata_dict.

    Note that unsigned dtype must have ``zero_point`` and signed dtype must
    not have ``zero_point``, to be consistent with tensor generated by calling
    compiled function from `CompGraph.compile(inputs, outspec)`.
......
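To illustrate the ``zero_point`` convention in the docstring above, a hedged sketch assuming the ``qint8``/``quint8`` convenience helpers in this module, which build quantized dtypes on top of create_quantized_dtype:

import megengine.core.tensor.dtype as dtype

q_signed = dtype.qint8(0.1)          # signed dtype: must NOT carry a zero_point
q_unsigned = dtype.quint8(0.1, 128)  # unsigned dtype: zero_point required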
@@ -13,6 +13,7 @@ import numpy as np
import pytest
import megengine as mge
import megengine.core.tensor.dtype as dtype
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
@@ -469,3 +470,18 @@ def test_2nd_grad_with_custom_gradient():
    np.testing.assert_almost_equal(
        x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
    )
@pytest.mark.parametrize("invalid_dtype", [np.uint8, np.int8, np.int32])
def test_attach_invalid_tensor_dtype(invalid_dtype):
gm = GradManager()
x = mge.tensor([1], dtype=invalid_dtype)
with pytest.raises(AssertionError):
gm.attach([x])
@pytest.mark.parametrize("differentible_dtype", [np.float32, np.float16])
def test_attach_differentible_tensor_dtype(differentible_dtype):
gm = GradManager()
x = mge.tensor([1], dtype=differentible_dtype)
gm.attach([x])