提交 5198b783 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge/functional): fix expand_dims for scalar

GitOrigin-RevId: 253ea608f7e45a86a90e53cf6159964b2ab54678
上级 88898e63
......@@ -851,7 +851,14 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
axis = get_axes()
ndim = inp.ndim + len(axis)
axis = sorted(i + ndim if i < 0 else i for i in axis)
assert axis, "axis could not be empty"
if inp._isscalar():
assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0])
if len(axis) == 1:
inp = copy(inp, device=None)
inp._unsetscalar()
return inp
axis = axis[1:]
op = builtin.AddAxis(axis=axis)
(result,) = apply(op, inp)
return result
......
......@@ -253,6 +253,19 @@ def test_expand_dims(is_varnode):
np.testing.assert_equal(y, yy.numpy())
def test_expand_dims_for_scalar():
x = np.array(1, dtype="float32")
xx = make_tensor(x, None)
for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
y = np.expand_dims(x, axis)
yy = F.expand_dims(xx, axis)
np.testing.assert_equal(y, yy.numpy())
for axis in [1, -2, (1, 2), (-2, -3)]:
np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_elemwise_dtype_promotion(is_varnode):
if is_varnode:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册