提交 07f8765f 编写于 作者: S scxfjiang 提交者: GitHub

Dev enhance sort ops (#3828)

* draft

* test sort op

* argsort

* argmax

* quick imple for top_k, to be tested

* feat

* test argmax

* pass all tests

* fix type

* batch_axis_non_change true

* final

* format

* refine by review

* accurate axis check

* Update oneflow/python/ops/transpose_util.py
Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: Mardino <42901638+MARD1NO@users.noreply.github.com>
Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Former-commit-id: a29c228c
上级 8f940881
......@@ -28,6 +28,8 @@ import oneflow.python.framework.dtype as dtype_util
import oneflow.python.framework.module as module_util
import oneflow.python.ops.math_unary_elementwise_ops as math_unary_elementwise_ops
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.ops.transpose_util import get_perm_when_transpose_axis_to_last_dim
from oneflow.python.ops.transpose_util import get_inversed_perm
@oneflow_export("math.add")
......@@ -1424,17 +1426,38 @@ def elem_cnt(
return remote_blob_util.RemoteBlob(out_lbi)
def _top_k_at_last_dim(
    input: remote_blob_util.BlobDef,
    k: int = 1,
    sorted: bool = True,
    name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
    """Build and run a single ``top_k`` user op on ``input``.

    Private helper behind ``top_k``; as its name says, it is meant to be
    applied to a Blob whose axis of interest is already the last one.

    Args:
        input: Blob wired to the op's ``in`` input.
        k: Value forwarded as the op's ``k`` attribute.
        sorted: Value forwarded as the op's ``sorted`` attribute.
        name: Op name. A unique ``TopK_``-prefixed name is generated when
            this is None.

    Returns:
        The first (and only declared) output Blob of the op, ``out``.
    """
    op_name = id_util.UniqueStr("TopK_") if name is None else name
    builder = flow.user_op_builder(op_name)
    builder = builder.Op("top_k").Input("in", [input]).Output("out")
    builder = builder.Attr("k", k).Attr("sorted", sorted)
    return builder.Build().InferAndTryRun().RemoteBlobList()[0]
@oneflow_export("math.top_k")
def top_k(
input: remote_blob_util.BlobDef,
axis: int = -1,
k: int = 1,
sorted: bool = True,
name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
"""Finds the indices of the k largest entries for the last dimension, the difference between other framework is that oneflow only return the indices.
"""Finds the indices of the k largest entries at specified axis, the difference between other framework is that oneflow only return the indices.
Args:
input (remote_blob_util.BlobDef): The input Blob
axis (int, optional): dimension to be calculated. Defaults to the last dim (-1)
k (int, optional): Number of top elements to look for along the last dimension. Defaults to 1.
sorted (bool, optional): If true the resulting k elements will be sorted by the values in descending order. Defaults to True.
name (Optional[str], optional): The name for the operation. Defaults to None.
......@@ -1461,13 +1484,29 @@ def top_k(
# out [2 3]
"""
name = name if name is not None else id_util.UniqueStr("TopK_")
num_axes = len(input.shape)
axis = axis if axis >= 0 else axis + num_axes
assert 0 <= axis < num_axes, "axis out of range"
if axis == num_axes - 1:
return _top_k_at_last_dim(input, k, sorted, name)
else:
perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis)
x = flow.transpose(input, perm, False, True, name + "_transpose")
x = _top_k_at_last_dim(x, k, sorted, name)
return flow.transpose(
x, get_inversed_perm(perm), False, True, name + "_inverse_transpose"
)
def _argmax_at_last_dim(
input: remote_blob_util.BlobDef, name: Optional[str] = None
) -> remote_blob_util.BlobDef:
return (
flow.user_op_builder(name if name is not None else id_util.UniqueStr("TopK_"))
.Op("top_k")
flow.user_op_builder(name if name is not None else id_util.UniqueStr("ArgMax_"))
.Op("argmax")
.Input("in", [input])
.Output("out")
.Attr("k", k)
.Attr("sorted", sorted)
.Build()
.InferAndTryRun()
.RemoteBlobList()[0]
......@@ -1476,12 +1515,13 @@ def top_k(
@oneflow_export("math.argmax")
def argmax(
input: remote_blob_util.BlobDef, name: Optional[str] = None
input: remote_blob_util.BlobDef, axis: int = -1, name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
"""The op computes the index with the largest value of a Blob.
"""The op computes the index with the largest value of a Blob at specified axis.
Args:
input (remote_blob_util.BlobDef): Input Blob
axis (int, optional): dimension to be calculated. Defaults to the last dim (-1)
name (Optional[str], optional): The name for the operation. Defaults to None.
Returns:
......@@ -1508,15 +1548,22 @@ def argmax(
# out [2 1]
"""
return (
flow.user_op_builder(name if name is not None else id_util.UniqueStr("ArgMax_"))
.Op("argmax")
.Input("in", [input])
.Output("out")
.Build()
.InferAndTryRun()
.RemoteBlobList()[0]
)
name = name if name is not None else id_util.UniqueStr("ArgMax_")
num_axes = len(input.shape)
axis = axis if axis >= 0 else axis + num_axes
assert 0 <= axis < num_axes, "axis out of range"
if axis == num_axes - 1:
return _argmax_at_last_dim(input, name)
else:
perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis)
x = flow.transpose(input, perm, False, True, name + "_transpose")
x = _argmax_at_last_dim(x, name)
x = flow.expand_dims(x, -1, name + "_expand_dims")
x = flow.transpose(
x, get_inversed_perm(perm), False, True, name + "_inverse_transpose"
)
x = flow.squeeze(x, [axis], name + "_squeeze")
return x
@oneflow_export("math.broadcast_to_compatible_with", "broadcast_to_compatible_with")
......
......@@ -21,27 +21,49 @@ import oneflow as flow
import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.remote_blob as remote_blob_util
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.ops.transpose_util import get_perm_when_transpose_axis_to_last_dim
from oneflow.python.ops.transpose_util import get_inversed_perm
def _sort_at_last_dim(
    input: remote_blob_util.BlobDef,
    direction: str = "ASCENDING",
    name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
    """Build and run a single ``sort`` user op on ``input``.

    Private helper behind ``sort``; as its name says, it is meant to be
    applied to a Blob whose axis of interest is already the last one.

    Args:
        input: Blob wired to the op's ``in`` input.
        direction: Either "ASCENDING" or "DESCENDING"; forwarded as the op's
            ``direction`` attribute.
        name: Op name. A unique ``Sort_``-prefixed name is generated when
            this is None.

    Returns:
        The first (and only declared) output Blob of the op, ``out``.

    Raises:
        AssertionError: If ``direction`` is not one of the two accepted
            values.
    """
    assert direction in ["ASCENDING", "DESCENDING"]
    op_name = id_util.UniqueStr("Sort_") if name is None else name
    builder = flow.user_op_builder(op_name)
    builder = builder.Op("sort").Input("in", [input]).Output("out")
    builder = builder.Attr("direction", direction)
    return builder.Build().InferAndTryRun().RemoteBlobList()[0]
@oneflow_export("sort")
def sort(
input: remote_blob_util.BlobDef,
axis: int = -1,
direction: str = "ASCENDING",
name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
"""This operator sorts the input Blob.
"""This operator sorts the input Blob at specified axis.
Args:
input (remote_blob_util.BlobDef): A Blob
axis (int, optional): dimension to be sorted. Defaults to the last dim (-1)
direction (str, optional): The direction in which to sort the Blob values. If the direction is "ASCENDING", The order of input will be sorted as ascending, else, the order of input will be sorted as descending. Defaults to "ASCENDING".
name (Optional[str], optional): The name for the operation. Defaults to None.
Returns:
remote_blob_util.BlobDef: The sorted Blob
For example:
For example:
.. code-block:: python
.. code-block:: python
import oneflow as flow
import numpy as np
......@@ -51,8 +73,7 @@ def sort(
@flow.global_function()
def sort_Job(x: tp.Numpy.Placeholder((5, ))
) -> tp.Numpy:
return flow.sort(input=x,
direction='ASCENDING')
return flow.sort(input=x)
x = np.array([10, 2, 9, 3, 7]).astype("float32")
out = sort_Job(x)
......@@ -60,10 +81,33 @@ def sort(
# out [ 2. 3. 7. 9. 10.]
"""
assert direction in ["ASCENDING", "DESCENDING"]
name = name if name is not None else id_util.UniqueStr("Sort_")
num_axes = len(input.shape)
axis = axis if axis >= 0 else axis + num_axes
assert 0 <= axis < num_axes, "axis out of range"
if axis == num_axes - 1:
return _sort_at_last_dim(input, direction, name)
else:
perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis)
x = flow.transpose(input, perm, False, True, name + "_transpose")
x = _sort_at_last_dim(x, direction, name)
return flow.transpose(
x, get_inversed_perm(perm), False, True, name + "_inverse_transpose"
)
def _argsort_at_last_dim(
input: remote_blob_util.BlobDef,
direction: str = "ASCENDING",
name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
assert direction in ["ASCENDING", "DESCENDING"]
return (
flow.user_op_builder(name if name is not None else id_util.UniqueStr("Sort_"))
.Op("sort")
flow.user_op_builder(
name if name is not None else id_util.UniqueStr("ArgSort_")
)
.Op("arg_sort")
.Input("in", [input])
.Output("out")
.Attr("direction", direction)
......@@ -76,18 +120,20 @@ def sort(
@oneflow_export("argsort")
def argsort(
input: remote_blob_util.BlobDef,
axis: int = -1,
direction: str = "ASCENDING",
name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
"""This operator sorts the input Blob and return the indices of sorted Blob.
"""This operator sorts the input Blob at specified axis and return the indices of the sorted Blob.
Args:
input (remote_blob_util.BlobDef): A Blob
axis (int, optional): dimension to be sorted. Defaults to the last dim (-1)
direction (str, optional): The direction in which to sort the Blob values. If the direction is "ASCENDING", The order of input will be sorted as ascending, else, the order of input will be sorted as descending. Defaults to "ASCENDING".
name (Optional[str], optional): The name for the operation. Defaults to None.
Returns:
remote_blob_util.BlobDef: The indices of sorted Blob
remote_blob_util.BlobDef: The indices of the sorted Blob
For example:
......@@ -101,8 +147,7 @@ def argsort(
@flow.global_function()
def argsort_Job(x: tp.Numpy.Placeholder((5, ))
) -> tp.Numpy:
return flow.argsort(input=x,
direction='ASCENDING')
return flow.argsort(input=x)
x = np.array([10, 2, 9, 3, 7]).astype("float32")
out = argsort_Job(x)
......@@ -111,15 +156,16 @@ def argsort(
"""
assert direction in ["ASCENDING", "DESCENDING"]
return (
flow.user_op_builder(
name if name is not None else id_util.UniqueStr("ArgSort_")
name = name if name is not None else id_util.UniqueStr("ArgSort_")
num_axes = len(input.shape)
axis = axis if axis >= 0 else axis + num_axes
assert 0 <= axis < num_axes, "axis out of range"
if axis == num_axes - 1:
return _argsort_at_last_dim(input, direction, name)
else:
perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis)
x = flow.transpose(input, perm, False, True, name + "_transpose")
x = _argsort_at_last_dim(x, direction, name)
return flow.transpose(
x, get_inversed_perm(perm), False, True, name + "_inverse_transpose"
)
.Op("arg_sort")
.Input("in", [input])
.Output("out")
.Attr("direction", direction)
.Build()
.InferAndTryRun()
.RemoteBlobList()[0]
)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
from typing import Sequence
def is_perm(perm: Sequence[int],) -> bool:
    """Return True if ``perm`` is a permutation of ``0 .. len(perm) - 1``.

    Args:
        perm: Candidate permutation, e.g. the ``perm`` argument of a
            transpose.

    Returns:
        True when ``perm`` contains every integer from 0 to
        ``len(perm) - 1`` exactly once (an empty sequence counts as a
        permutation), otherwise False.
    """
    # sorted() accepts any iterable and already returns a new list, so no
    # intermediate list(perm) copy is needed.
    return sorted(perm) == list(range(len(perm)))
def get_perm_when_transpose_axis_to_last_dim(num_axes: int, axis: int,) -> tuple:
    """Return the perm that moves ``axis`` to the last dimension.

    All other axes keep their relative order, e.g. ``num_axes=4, axis=1``
    gives ``(0, 2, 3, 1)``.

    Args:
        num_axes: Number of axes of the tensor to be transposed.
        axis: Axis to move to the end; a negative value counts from the
            back.

    Returns:
        The permutation as a tuple of length ``num_axes``.

    Raises:
        AssertionError: If ``axis`` is outside ``[-num_axes, num_axes)``.
    """
    if axis < 0:
        axis += num_axes
    assert 0 <= axis < num_axes, "axis out of range"
    axes_before = tuple(range(axis))
    axes_after = tuple(range(axis + 1, num_axes))
    return axes_before + axes_after + (axis,)
def get_inversed_perm(perm: Sequence[int],) -> tuple:
    """Return the inverse of the permutation ``perm``.

    The result undoes a transpose by ``perm``:
    ``x == transpose(transpose(x, perm), get_inversed_perm(perm))``.

    Args:
        perm: A permutation of ``0 .. len(perm) - 1``.

    Returns:
        A tuple ``inv`` such that ``inv[perm[i]] == i`` for every ``i``.

    Raises:
        AssertionError: If ``perm`` is not a valid permutation.
    """
    assert is_perm(perm)
    inversed_perm = [-1] * len(perm)
    # Output position `position` took its data from `source_axis`, so the
    # inverse must send `source_axis` back to `position`.
    for position, source_axis in enumerate(perm):
        inversed_perm[source_axis] = position
    return tuple(inversed_perm)
......@@ -27,7 +27,7 @@ for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
def compare_with_tensorflow(device_type, in_shape, data_type):
def compare_with_tensorflow(device_type, in_shape, axis, data_type):
assert device_type in ["gpu", "cpu"]
assert data_type in ["float32", "double", "int8", "int32", "int64"]
flow.clear_default_session()
......@@ -43,13 +43,13 @@ def compare_with_tensorflow(device_type, in_shape, data_type):
)
):
with flow.scope.placement(device_type, "0:0"):
return flow.math.argmax(input)
return flow.math.argmax(input, axis)
input = (np.random.random(in_shape) * 100).astype(type_name_to_np_type[data_type])
# OneFlow
of_out = ArgMaxJob([input]).get().numpy_list()[0]
# TensorFlow
tf_out = tf.math.argmax(input, -1).numpy()
tf_out = tf.math.argmax(input, axis).numpy()
tf_out = np.array([tf_out]) if isinstance(tf_out, np.int64) else tf_out
assert np.array_equal(of_out, tf_out)
......@@ -65,6 +65,17 @@ def gen_arg_list():
(10, 10, 2000),
(10, 10000),
]
arg_dict["axis"] = [-1]
arg_dict["data_type"] = ["float32", "double", "int32", "int64"]
return GenArgList(arg_dict)
def gen_arg_list_for_test_axis():
    """Build the argument list for the argmax tests that vary ``axis``.

    Uses one fixed 4-D input shape with several positive and negative axis
    values. The key order is the positional argument order expected by
    ``compare_with_tensorflow`` (device_type, in_shape, axis, data_type).
    NOTE(review): GenArgList presumably expands the dict into every
    combination of its values; confirm against test_util.
    """
    arg_dict = OrderedDict(
        [
            ("device_type", ["gpu", "cpu"]),
            ("in_shape", [(10, 10, 20, 30)]),
            ("axis", [-2, -1, 0, 1, 2]),
            ("data_type", ["float32", "double", "int32", "int64"]),
        ]
    )
    return GenArgList(arg_dict)
......@@ -75,6 +86,8 @@ class TestArgmax(flow.unittest.TestCase):
def test_argmax(test_case):
for arg in gen_arg_list():
compare_with_tensorflow(*arg)
for arg in gen_arg_list_for_test_axis():
compare_with_tensorflow(*arg)
if __name__ == "__main__":
......
......@@ -27,7 +27,7 @@ for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
def compare_with_tensorflow(device_type, in_shape, direction, data_type):
def compare_with_tensorflow(device_type, in_shape, axis, direction, data_type):
assert device_type in ["gpu", "cpu"]
assert data_type in ["float32", "double", "int8", "int32", "int64"]
flow.clear_default_session()
......@@ -43,13 +43,13 @@ def compare_with_tensorflow(device_type, in_shape, direction, data_type):
)
):
with flow.scope.placement(device_type, "0:0"):
return flow.argsort(input, direction)
return flow.argsort(input, axis, direction)
input = (np.random.random(in_shape) * 100).astype(type_name_to_np_type[data_type])
# OneFlow
of_out = ArgSortJob([input]).get().numpy_list()[0]
# TensorFlow
tf_out = tf.argsort(input, axis=-1, direction=direction)
tf_out = tf.argsort(input, axis, direction)
assert np.array_equal(of_out, tf_out.numpy())
......@@ -58,6 +58,18 @@ def gen_arg_list():
arg_dict = OrderedDict()
arg_dict["device_type"] = ["cpu", "gpu"]
arg_dict["in_shape"] = [(100,), (100, 100), (10, 10, 200)]
arg_dict["axis"] = [-1]
arg_dict["direction"] = ["ASCENDING", "DESCENDING"]
arg_dict["data_type"] = ["float32", "double", "int32", "int64"]
return GenArgList(arg_dict)
def gen_arg_list_for_test_axis():
arg_dict = OrderedDict()
arg_dict["device_type"] = ["cpu", "gpu"]
arg_dict["in_shape"] = [(10, 10, 20)]
arg_dict["axis"] = [-2, -1, 0, 1, 2]
arg_dict["direction"] = ["ASCENDING", "DESCENDING"]
arg_dict["data_type"] = ["float32", "double", "int32", "int64"]
......@@ -69,6 +81,8 @@ class TestArgsort(flow.unittest.TestCase):
def test_argsort(test_case):
for arg in gen_arg_list():
compare_with_tensorflow(*arg)
for arg in gen_arg_list_for_test_axis():
compare_with_tensorflow(*arg)
if __name__ == "__main__":
......
......@@ -27,7 +27,7 @@ for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
def compare_with_tensorflow(device_type, in_shape, direction, data_type):
def compare_with_tensorflow(device_type, in_shape, axis, direction, data_type):
assert device_type in ["gpu", "cpu"]
assert data_type in ["float32", "double", "int8", "int32", "int64"]
flow.clear_default_session()
......@@ -43,13 +43,13 @@ def compare_with_tensorflow(device_type, in_shape, direction, data_type):
)
):
with flow.scope.placement(device_type, "0:0"):
return flow.sort(input, direction)
return flow.sort(input, axis, direction)
input = (np.random.random(in_shape) * 100).astype(type_name_to_np_type[data_type])
# OneFlow
of_out = SortJob([input]).get().numpy_list()[0]
# TensorFlow
tf_out = tf.sort(input, axis=-1, direction=direction)
tf_out = tf.sort(input, axis, direction)
assert np.array_equal(of_out, tf_out.numpy())
......@@ -58,6 +58,18 @@ def gen_arg_list():
arg_dict = OrderedDict()
arg_dict["device_type"] = ["cpu", "gpu"]
arg_dict["in_shape"] = [(100,), (100, 100), (10, 10, 200)]
arg_dict["axis"] = [-1]
arg_dict["direction"] = ["ASCENDING", "DESCENDING"]
arg_dict["data_type"] = ["float32", "double", "int32", "int64"]
return GenArgList(arg_dict)
def gen_arg_list_for_test_axis():
arg_dict = OrderedDict()
arg_dict["device_type"] = ["cpu", "gpu"]
arg_dict["in_shape"] = [(10, 10, 20)]
arg_dict["axis"] = [-2, -1, 0, 1, 2]
arg_dict["direction"] = ["ASCENDING", "DESCENDING"]
arg_dict["data_type"] = ["float32", "double", "int32", "int64"]
......@@ -69,6 +81,8 @@ class TestSort(flow.unittest.TestCase):
def test_sort(test_case):
for arg in gen_arg_list():
compare_with_tensorflow(*arg)
for arg in gen_arg_list_for_test_axis():
compare_with_tensorflow(*arg)
if __name__ == "__main__":
......
......@@ -21,13 +21,15 @@ import oneflow as flow
import tensorflow as tf
from test_util import GenArgList, type_name_to_flow_type, type_name_to_np_type
import oneflow.typing as oft
from oneflow.python.ops.transpose_util import get_perm_when_transpose_axis_to_last_dim
from oneflow.python.ops.transpose_util import get_inversed_perm
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
def compare_with_tensorflow(device_type, in_shape, k, data_type, sorted):
def compare_with_tensorflow(device_type, in_shape, axis, k, data_type, sorted):
assert device_type in ["gpu", "cpu"]
assert data_type in ["float32", "double", "int8", "int32", "int64"]
flow.clear_default_session()
......@@ -43,16 +45,20 @@ def compare_with_tensorflow(device_type, in_shape, k, data_type, sorted):
)
):
with flow.scope.placement(device_type, "0:0"):
return flow.math.top_k(input, k, sorted)
return flow.math.top_k(input, axis, k, sorted)
input = (np.random.random(in_shape) * 100).astype(type_name_to_np_type[data_type])
# OneFlow
of_out = TopKJob([input]).get().numpy_list()[0]
# TensorFlow
if k <= in_shape[-1]:
_, tf_out = tf.math.top_k(input, k, sorted)
if k <= in_shape[axis]:
perm = get_perm_when_transpose_axis_to_last_dim(len(in_shape), axis)
x = tf.transpose(input, perm)
_, indices = tf.math.top_k(x, k, sorted)
tf_out = tf.transpose(indices, get_inversed_perm(perm))
else:
tf_out = tf.argsort(input, axis=-1, direction="DESCENDING", stable=True)
tf_out = tf.argsort(input, axis, direction="DESCENDING", stable=True)
assert np.array_equal(of_out, tf_out.numpy())
......@@ -61,6 +67,19 @@ def gen_arg_list():
arg_dict = OrderedDict()
arg_dict["device_type"] = ["cpu", "gpu"]
arg_dict["in_shape"] = [(100,), (100, 100), (10, 500), (10, 10, 500)]
arg_dict["axis"] = [-1]
arg_dict["k"] = [1, 50, 200]
arg_dict["data_type"] = ["float32", "double", "int32", "int64"]
arg_dict["sorted"] = [True]
return GenArgList(arg_dict)
def gen_arg_list_for_test_axis():
arg_dict = OrderedDict()
arg_dict["device_type"] = ["cpu", "gpu"]
arg_dict["in_shape"] = [(10, 10, 500)]
arg_dict["axis"] = [-2, -1, 0, 1, 2]
arg_dict["k"] = [1, 50, 200]
arg_dict["data_type"] = ["float32", "double", "int32", "int64"]
arg_dict["sorted"] = [True]
......@@ -73,6 +92,8 @@ class TestTopK(flow.unittest.TestCase):
def test_top_k(test_case):
for arg in gen_arg_list():
compare_with_tensorflow(*arg)
for arg in gen_arg_list_for_test_axis():
compare_with_tensorflow(*arg)
if __name__ == "__main__":
......
......@@ -42,7 +42,7 @@ REGISTER_USER_OP("transpose")
const auto& perm = ctx->Attr<std::vector<int32_t>>("perm");
CHECK_EQ_OR_RETURN(perm.size(), in_shape.NumAxes());
CheckIsPerm(perm);
if (perm.at(0) != 0) { CHECK_OR_RETURN(!in_tensor_desc->is_dynamic()); }
// if (perm.at(0) != 0) { CHECK_OR_RETURN(!in_tensor_desc->is_dynamic()); }
*out_tensor_desc = *in_tensor_desc;
FOR_RANGE(size_t, i, 0, perm.size()) { out_shape->Set(i, in_shape.At(perm[i])); }
return Maybe<void>::Ok();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册