Unverified commit e6b56df5, authored by Xiaoyu Zhang, committed by GitHub

Restruct sort and argsort op (#5911)

* restruct sort op

* add sort autotest

* restruct argsort and add autotest

* add tensor autotest

* auto format by CI

* add l1loss autotest

* auto format by CI
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent 6374e59d
@@ -414,6 +414,10 @@
signature: "Tensor DimGather(Tensor x, Tensor indices, *, Int32 dim)"
bind_python: True
- name: "arg_sort"
signature: "Tensor ArgSort(Tensor in, *, String direction)"
bind_python: True
- name: "gather_nd"
signature: "Tensor GatherNd(Tensor params, Tensor indices, *)"
bind_python: True
......
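For context: the new functional_api.yaml entry above exposes the C++ ArgSortFunctor (added below) to Python as flow.F.arg_sort, taking a tensor and a direction string ("ASCENDING" or "DESCENDING"), which is exactly how the reworked Argsort and Sort modules call it. A minimal usage sketch, assuming a OneFlow build that includes this commit and the flow.F namespace used elsewhere in this diff:

```python
# Minimal sketch: calling the newly registered functional directly.
# Assumes a OneFlow build containing this commit, so that flow.F.arg_sort
# is bound as declared above ("Tensor ArgSort(Tensor in, *, String direction)").
import numpy as np
import oneflow as flow

x = flow.Tensor(np.array([[3.0, 1.0, 2.0], [0.5, 2.5, 1.5]]))

# arg_sort sorts along the last dimension only; "direction" is the string
# attr forwarded to the op, either "ASCENDING" or "DESCENDING".
idx_asc = flow.F.arg_sort(x, "ASCENDING")
idx_desc = flow.F.arg_sort(x, "DESCENDING")

print(idx_asc.numpy())   # expected [[1 2 0], [0 2 1]]
print(idx_desc.numpy())  # expected [[0 2 1], [1 2 0]]
```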
@@ -509,6 +509,22 @@ class DimScatterMulScalarFunctor {
std::shared_ptr<OpExpr> op_;
};
class ArgSortFunctor {
public:
ArgSortFunctor() {
op_ = CHECK_JUST(one::OpBuilder("arg_sort").Input("in").Output("out").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,
const std::string direction) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::string>("direction", direction));
return OpInterpUtil::Dispatch<Tensor>(*op_, {in}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class GatherNdFunctor {
public:
GatherNdFunctor() {
@@ -1425,6 +1441,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::ExpandDimsFunctor>("ExpandDims");
m.add_functor<impl::GatherFunctor>("Gather");
m.add_functor<impl::DimGatherFunctor>("DimGather");
m.add_functor<impl::ArgSortFunctor>("ArgSort");
m.add_functor<impl::GatherNdFunctor>("GatherNd");
m.add_functor<impl::ScatterNdFunctor>("ScatterNd");
m.add_functor<impl::ScatterNdLikeFunctor>("ScatterNdLike");
......
@@ -26,25 +26,18 @@ class Argsort(Module):
def __init__(self, dim: int = -1, descending: bool = False) -> None:
super().__init__()
self.dim = dim
direction = "DESCENDING" if descending else "ASCENDING"
self._argsort_op = (
flow.builtin_op("arg_sort")
.Input("in")
.Output("out")
.Attr("direction", direction)
.Build()
)
self.direction = "DESCENDING" if descending else "ASCENDING"
def forward(self, input):
num_dims = len(input.shape)
dim = self.dim if self.dim >= 0 else self.dim + num_dims
assert 0 <= dim < num_dims, "dim out of range"
if dim == num_dims - 1:
return self._argsort_op(input)[0]
return flow.F.arg_sort(input, self.direction)
else:
perm = get_perm_when_transpose_axis_to_last_dim(num_dims, dim)
x = flow.F.transpose(input, perm=perm)
x = self._argsort_op(x)[0]
x = flow.F.arg_sort(x, self.direction)
return flow.F.transpose(x, perm=get_inversed_perm(perm))
......
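The reworked Argsort.forward keeps the existing strategy: the arg_sort op only sorts along the last dimension, so any other dim is handled by permuting that dim to the end, calling flow.F.arg_sort, and undoing the permutation on the returned indices. The helpers get_perm_when_transpose_axis_to_last_dim and get_inversed_perm are not part of this diff; the sketch below is a plain-numpy illustration of what they are presumably computing, inferred only from their call sites:

```python
# Illustrative sketch (assumption): the likely behavior of the permutation
# helpers used by Argsort.forward, inferred only from their call sites here.
import numpy as np

def perm_axis_to_last(num_dims, dim):
    # Move `dim` to the last position, keeping the relative order of the rest.
    return [d for d in range(num_dims) if d != dim] + [dim]

def inverse_perm(perm):
    # inv[old_axis] = position of old_axis in `perm`, so transposing by `inv`
    # undoes a transpose by `perm`.
    inv = [0] * len(perm)
    for new_pos, old_pos in enumerate(perm):
        inv[old_pos] = new_pos
    return inv

# Cross-check the move-sort-move-back pattern against numpy's argsort.
x = np.random.randn(2, 3, 4)
dim = 0
perm = perm_axis_to_last(x.ndim, dim)               # [1, 2, 0]
idx_last = np.argsort(x.transpose(perm), axis=-1)   # sort along the moved axis
idx = idx_last.transpose(inverse_perm(perm))        # permute indices back
assert (idx == np.argsort(x, axis=dim)).all()
```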
@@ -26,26 +26,19 @@ class Sort(Module):
def __init__(self, dim: int = -1, descending: bool = False) -> None:
super().__init__()
self.dim = dim
direction = "DESCENDING" if descending else "ASCENDING"
self._argsort_op = (
flow.builtin_op("arg_sort")
.Input("in")
.Output("out")
.Attr("direction", direction)
.Build()
)
self.direction = "DESCENDING" if descending else "ASCENDING"
def forward(self, input):
num_dims = len(input.shape)
dim = self.dim if self.dim >= 0 else self.dim + num_dims
assert 0 <= dim < num_dims, "dim out of range"
if dim == num_dims - 1:
indices = self._argsort_op(input)[0]
indices = flow.F.arg_sort(input, self.direction)
return (flow.gather(input, indices, dim), indices)
else:
perm = get_perm_when_transpose_axis_to_last_dim(num_dims, dim)
x = flow.F.transpose(input, perm=perm)
indices = self._argsort_op(x)[0]
indices = flow.F.arg_sort(x, self.direction)
indices = flow.F.transpose(indices, perm=get_inversed_perm(perm))
return (flow.gather(input, indices, dim), indices)
......
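Sort.forward follows the same pattern but additionally returns the sorted values, obtained by gathering the input with the argsort indices along dim, matching torch.sort's (values, indices) convention. A small numpy illustration of that gather identity, where np.take_along_axis stands in for the torch.gather-style flow.gather call above:

```python
# Sketch of the sort-by-gather identity used in Sort.forward:
# values == gather(input, argsort(input, dim), dim).
import numpy as np

x = np.random.randn(2, 5)
dim = 1

indices = np.argsort(x, axis=dim)                  # what flow.F.arg_sort yields (last dim here)
values = np.take_along_axis(x, indices, axis=dim)  # what flow.gather(input, indices, dim) yields

assert (values == np.sort(x, axis=dim)).all()
```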
@@ -22,6 +22,7 @@ from test_util import GenArgList, type_name_to_flow_type
import oneflow as flow
import oneflow.unittest
from automated_test_util import *
def _test_argsort(test_case, data_shape, axis, descending, data_type, device):
@@ -62,6 +63,15 @@ class TestArgsort(flow.unittest.TestCase):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
@autotest(auto_backward=False)
def test_argsort_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=4).to(device)
y = torch.argsort(
x, dim=random(low=-4, high=4).to(int), descending=random_bool()
)
return y
if __name__ == "__main__":
unittest.main()
@@ -22,6 +22,7 @@ from test_util import GenArgList
import oneflow as flow
import oneflow.unittest
from automated_test_util import *
def _np_l1loss(np_input, np_target):
@@ -75,6 +76,28 @@ class TestL1LossModule(flow.unittest.TestCase):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
@autotest()
def test_l1loss_module_with_random_data(test_case):
k = random(1, 6).to(int)
dim0 = random(1, 10).to(int)
dim1 = random(1, 10).to(int)
dim2 = random(1, 10).to(int)
dim3 = random(1, 10).to(int)
dim4 = random(1, 10).to(int)
reduction = oneof("none", "sum", "mean")
loss = torch.nn.L1Loss(reduction=reduction | nothing())
loss.train(random())
device = random_device()
loss.to(device)
input = random_pytorch_tensor(
ndim=k, dim0=dim0, dim1=dim1, dim2=dim2, dim3=dim3, dim4=dim4
).to(device)
target = random_pytorch_tensor(
ndim=k, dim0=dim0, dim1=dim1, dim2=dim2, dim3=dim3, dim4=dim4
).to(device)
y = loss(input, target)
return y
if __name__ == "__main__":
unittest.main()
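The added L1Loss case relies on the autotest framework (autotest, random, oneof, random_pytorch_tensor) to generate random shapes and reductions and compare OneFlow against PyTorch, including backward. For readers without that harness, here is a minimal hand-rolled check of the same module against a numpy reference in the spirit of the _np_l1loss helper above; the tolerances are illustrative assumptions, not taken from the test file:

```python
# Minimal hand-rolled check of flow.nn.L1Loss against a numpy reference,
# mirroring what the _np_l1loss helper computes for reduction="mean".
# The tolerances below are illustrative assumptions.
import numpy as np
import oneflow as flow

np_input = np.random.randn(4, 3).astype(np.float32)
np_target = np.random.randn(4, 3).astype(np.float32)

loss = flow.nn.L1Loss(reduction="mean")
of_out = loss(flow.Tensor(np_input), flow.Tensor(np_target))

np_out = np.abs(np_input - np_target).mean()
assert np.allclose(of_out.numpy(), np_out, rtol=1e-5, atol=1e-5)
```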
@@ -22,6 +22,7 @@ from test_util import GenArgList, type_name_to_flow_type
import oneflow as flow
import oneflow.unittest
from automated_test_util import *
def _test_sort(test_case, data_shape, axis, descending, data_type, device):
@@ -75,6 +76,13 @@ class TestSort(flow.unittest.TestCase):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
@autotest(auto_backward=False)
def test_sort_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=4).to(device)
y = torch.sort(x, dim=random(low=-4, high=4).to(int), descending=random_bool())
return y[0], y[1]
if __name__ == "__main__":
unittest.main()
@@ -482,6 +482,20 @@ class TestTensor(flow.unittest.TestCase):
y = x.acosh()
return y
@autotest(auto_backward=False)
def test_sort_tensor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=4).to(device)
y = x.sort(dim=random(low=-4, high=4).to(int), descending=random_bool())
return y[0], y[1]
@autotest(auto_backward=False)
def test_argsort_tensor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=4).to(device)
y = x.argsort(dim=random(low=-4, high=4).to(int), descending=random_bool())
return y
def test_mean(test_case):
input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
of_out = input.mean(dim=0)
......