提交 7cb3ad8a 编写于 作者: M Megvii Engine Team 提交者: XindaH

fix(mge/functional): fix one_hot irregular coding style

    From GITHUB 399

    ORIGINAL_AUTHOR=Asthestarsfalll <1186454801@qq.com>
    COPYBARA_INTEGRATE_REVIEW=https://github.com/MegEngine/MegEngine/pull/399 from Asthestarsfalll:master 541bf3af
    GITHUB_PUBLIC_PR_NUMBER=399
    GITHUB_PR_URL=https://github.com/MegEngine/MegEngine/pull/399

GitOrigin-RevId: 5df007207a24f27acc85022b8d08133d8e2be1b9
上级 535784d4
...@@ -1587,8 +1587,10 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: ...@@ -1587,8 +1587,10 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
[0 0 1 0] [0 0 1 0]
[0 0 0 1]] [0 0 0 1]]
""" """
zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device) zeros_tensor = zeros(
ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device) list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device
)
ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)
op = builtin.IndexingSetOneHot(axis=inp.ndim) op = builtin.IndexingSetOneHot(axis=inp.ndim)
(result,) = apply(op, zeros_tensor, inp, ones_tensor) (result,) = apply(op, zeros_tensor, inp, ones_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册