提交 5a01de78 编写于 作者: M Megvii Engine Team

fix(mge): fix scalar transpose

GitOrigin-RevId: c2b9e025c7f305975509e195f6c8e2c2cf7f81b7
上级 6b9ac894
......@@ -392,6 +392,13 @@ class ArrayMethodMixin(abc.ABC):
return _broadcast(self, _expand_args(args))
def transpose(self, *args):
if self.ndim == 0:
assert (
len(args) == 0
), "transpose for scalar does not accept additional args"
ret = self.to(self.device)
setscalar(ret)
return ret
if not args:
args = range(self.ndim)[::-1]
return _transpose(self, _expand_args(args))
......
......@@ -50,3 +50,8 @@ def test_elemementwise():
def test_astype():
a = Tensor(1.0)
assert a.astype("int32").ndim == 0
def test_tranpose():
a = Tensor(1.0)
assert a.transpose().ndim == 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册