提交 fbf5cb25 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): upgrade masked_select

上级 ca149e3f
......@@ -302,5 +302,10 @@ class TestTorchFuncsReduction:
}
})
tt1 = ttorch.masked_select(ttx, ttx > 0.3)
assert (tt1 == torch.tensor([1.1799, 0.4652, 1.0866, 1.3533, 0.8139,
0.9073, 2.1392, 0.6403, 0.4041])).all()
assert ttorch.isclose(tt1, torch.tensor([1.1799, 0.4652, 1.0866, 1.3533, 0.8139,
0.9073, 2.1392, 0.6403, 0.4041]), atol=1e-4).all()
tt2 = ttorch.masked_select(ttx, ttx > 0.3, reduce=False)
assert ttorch.isclose(tt2, ttorch.tensor({
'a': [1.1799, 0.4652, 1.0866, 1.3533],
'b': {'x': [0.8139, 0.9073, 2.1392, 0.6403, 0.4041]},
}), atol=1e-4).all()
......@@ -215,7 +215,7 @@ class TestTorchTensorReduction:
[-1.8267, 1.3676, -1.4490, -2.0224]])
t1 = tx.masked_select(tx > 0.3)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([0.9820, 0.8108, 1.0868, 1.3676])).all()
assert ttorch.isclose(t1, torch.tensor([0.9820, 0.8108, 1.0868, 1.3676]), atol=1e-4).all()
ttx = ttorch.tensor({
'a': [[1.1799, 0.4652, -1.7895],
......@@ -227,5 +227,10 @@ class TestTorchTensorReduction:
}
})
tt1 = ttx.masked_select(ttx > 0.3)
assert (tt1 == torch.tensor([1.1799, 0.4652, 1.0866, 1.3533, 0.8139,
0.9073, 2.1392, 0.6403, 0.4041])).all()
assert ttorch.isclose(tt1, torch.tensor([1.1799, 0.4652, 1.0866, 1.3533, 0.8139,
0.9073, 2.1392, 0.6403, 0.4041]), atol=1e-4).all()
tt2 = ttx.masked_select(ttx > 0.3, reduce=False)
assert ttorch.isclose(tt2, ttorch.tensor({
'a': [1.1799, 0.4652, 1.0866, 1.3533],
'b': {'x': [0.8139, 0.9073, 2.1392, 0.6403, 0.4041]},
}), atol=1e-4).all()
......@@ -408,11 +408,34 @@ def std(input, *args, reduce=None, **kwargs):
pass # pragma: no cover
# noinspection PyShadowingBuiltins
@doc_from_base()
# noinspection PyShadowingBuiltins,PyUnusedLocal
@rmreduce()
@func_treelize(return_type=Object)
def masked_select(input, mask, *args, **kwargs):
def _masked_select_r(input, mask, *args, **kwargs):
return torch.masked_select(input, mask, *args, **kwargs)
# noinspection PyShadowingBuiltins
@func_treelize()
def _masked_select_nr(input, mask, *args, **kwargs):
return torch.masked_select(input, mask, *args, **kwargs)
# noinspection PyUnusedLocal
def _ms_determine(mask, *args, out=None, **kwargs):
return False if args or kwargs else None
# noinspection PyUnusedLocal
def _ms_condition(mask, *args, out=None, **kwargs):
return not args and not kwargs
# noinspection PyShadowingBuiltins,PyUnusedLocal
@doc_from_base()
@auto_reduce(_masked_select_r, _masked_select_nr,
_ms_determine, _ms_condition)
def masked_select(input, mask, *args, reduce=None, **kwargs):
"""
Returns a new 1-D tensor which indexes the ``input`` tensor
according to the boolean mask ``mask`` which is a BoolTensor.
......@@ -443,5 +466,10 @@ def masked_select(input, mask, *args, **kwargs):
[-0.0496, 2.1392, 0.6403, 0.4041]])
>>> ttorch.masked_select(tt, tt > 0.3)
tensor([1.1799, 0.4652, 1.0866, 1.3533, 0.8139, 0.9073, 2.1392, 0.6403, 0.4041])
>>> ttorch.masked_select(tt, tt > 0.3, reduce=False)
<Tensor 0x7fcb64456b38>
├── a --> tensor([1.1799, 0.4652, 1.0866, 1.3533])
└── b --> <Tensor 0x7fcb64456a58>
└── x --> tensor([0.8139, 0.9073, 2.1392, 0.6403, 0.4041])
"""
return torch.masked_select(input, mask, *args, **kwargs)
pass # pragma: no cover
......@@ -793,14 +793,33 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.index_select(dim, index)
@doc_from_base()
# noinspection PyShadowingBuiltins,PyUnusedLocal
@rmreduce()
@method_treelize(return_type=Object)
def __masked_select_r(self, mask, *args, **kwargs):
return torch.masked_select(self, mask, *args, **kwargs)
# noinspection PyShadowingBuiltins
@method_treelize()
def __masked_select_nr(self, mask, *args, **kwargs):
return torch.masked_select(self, mask, *args, **kwargs)
# noinspection PyUnusedLocal,PyMethodParameters,PyMethodMayBeStatic
def __ms_determine(mask, *args, out=None, **kwargs):
return False if args or kwargs else None
# noinspection PyUnusedLocal,PyMethodParameters,PyMethodMayBeStatic
def __ms_condition(mask, *args, out=None, **kwargs):
return not args and not kwargs
@doc_from_base()
@auto_reduce(__masked_select_r, __masked_select_nr,
__ms_determine, __ms_condition)
def masked_select(self, mask):
"""
See :func:`treetensor.torch.masked_select`.
"""
return self.masked_select(mask)
pass # pragma: no cover
# noinspection PyUnusedLocal
@post_reduce(torch.std)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册