提交 541bf3af 编写于 作者: Asthestarsfalll's avatar Asthestarsfalll

fix(mge/functional): fix one_hot irregular coding style

上级 409aa489
......@@ -1587,8 +1587,8 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
[0 0 1 0]
[0 0 0 1]]
"""
zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device)
ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device)
zeros_tensor = zeros(list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device)
ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)
op = builtin.IndexingSetOneHot(axis=inp.ndim)
(result,) = apply(op, zeros_tensor, inp, ones_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册