From 7cb3ad8a3efc0b971c86eeab16722cc44d81e16b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 8 Dec 2021 12:34:25 +0800 Subject: [PATCH] fix(mge/functional): fix one_hot irregular coding style From GITHUB 399 ORIGINAL_AUTHOR=Asthestarsfalll <1186454801@qq.com> COPYBARA_INTEGRATE_REVIEW=https://github.com/MegEngine/MegEngine/pull/399 from Asthestarsfalll:master 541bf3af29520b1caee552012da7568afb206632 GITHUB_PUBLIC_PR_NUMBER=399 GITHUB_PR_URL=https://github.com/MegEngine/MegEngine/pull/399 GitOrigin-RevId: 5df007207a24f27acc85022b8d08133d8e2be1b9 --- imperative/python/megengine/functional/nn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index c127cc572..9ababcebe 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1587,8 +1587,10 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: [0 0 1 0] [0 0 0 1]] """ - zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device) - ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device) + zeros_tensor = zeros( + list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device + ) + ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device) op = builtin.IndexingSetOneHot(axis=inp.ndim) (result,) = apply(op, zeros_tensor, inp, ones_tensor) -- GitLab