提交 4500faf1 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(python_module/mge): one_hot: no default value for num_classes

GitOrigin-RevId: c4a53108806ddfb3e37ef7faa5d88afa31718d74
上级 b9cf0171
...@@ -427,7 +427,7 @@ def batch_norm2d( ...@@ -427,7 +427,7 @@ def batch_norm2d(
return output return output
def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor: def one_hot(inp: Tensor, num_classes: int) -> Tensor:
r""" r"""
Perform one-hot encoding for the input tensor. Perform one-hot encoding for the input tensor.
...@@ -457,8 +457,6 @@ def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor: ...@@ -457,8 +457,6 @@ def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor:
""" """
comp_node, comp_graph = _decide_comp_node_and_comp_graph(inp) comp_node, comp_graph = _decide_comp_node_and_comp_graph(inp)
if num_classes == -1:
num_classes = inp.max() + 1
zeros = mgb.make_immutable(value=0, comp_node=comp_node, comp_graph=comp_graph) zeros = mgb.make_immutable(value=0, comp_node=comp_node, comp_graph=comp_graph)
zeros_symvar = zeros.broadcast(inp.shapeof(), num_classes) zeros_symvar = zeros.broadcast(inp.shapeof(), num_classes)
......
...@@ -8,7 +8,7 @@ from megengine.test import assertTensorClose ...@@ -8,7 +8,7 @@ from megengine.test import assertTensorClose
def test_onehot_low_dimension(): def test_onehot_low_dimension():
inp = tensor(np.arange(1, 4, dtype=np.int32)) inp = tensor(np.arange(1, 4, dtype=np.int32))
out = F.one_hot(inp) out = F.one_hot(inp, num_classes=4)
assertTensorClose( assertTensorClose(
out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)] out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册