Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5f79c7fb
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5f79c7fb
编写于
7月 02, 2018
作者:
Y
Yibing Liu
提交者:
GitHub
7月 02, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #11174 from kuke/argsort_dev
Add the argsort operator
上级
66c91911
9386ac0a
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
445 addition
and
0 deletion
+445
-0
doc/fluid/api/layers.rst
doc/fluid/api/layers.rst
+8
-0
paddle/fluid/operators/argsort_op.cc
paddle/fluid/operators/argsort_op.cc
+87
-0
paddle/fluid/operators/argsort_op.cu
paddle/fluid/operators/argsort_op.cu
+151
-0
paddle/fluid/operators/argsort_op.h
paddle/fluid/operators/argsort_op.h
+81
-0
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+53
-0
python/paddle/fluid/tests/unittests/test_argsort_op.py
python/paddle/fluid/tests/unittests/test_argsort_op.py
+56
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+9
-0
未找到文件。
doc/fluid/api/layers.rst
浏览文件 @
5f79c7fb
...
@@ -1468,6 +1468,14 @@ argmax
...
@@ -1468,6 +1468,14 @@ argmax
.. autofunction:: paddle.fluid.layers.argmax
.. autofunction:: paddle.fluid.layers.argmax
:noindex:
:noindex:
.. _api_fluid_layers_argsort:
argsort
-------
.. autofunction:: paddle.fluid.layers.argsort
:noindex:
.. _api_fluid_layers_ones:
.. _api_fluid_layers_ones:
ones
ones
...
...
paddle/fluid/operators/argsort_op.cc
0 → 100644
浏览文件 @
5f79c7fb
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/argsort_op.h"
namespace
paddle
{
namespace
operators
{
class
ArgsortOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ArgsortOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of ArgsortOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Indices"
),
"Output(Indices) of ArgsortOp should not be null."
);
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
int
axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
);
auto
num_dims
=
in_dims
.
size
();
PADDLE_ENFORCE
(
axis
<
num_dims
,
"Attr(axis) %d of ArgsortOp is out of bounds for Input(X)'s "
"rank %d."
,
axis
,
num_dims
);
PADDLE_ENFORCE
(
axis
>=
-
num_dims
,
"Attr(axis) %d of ArgsortOp must be not less than "
"-rank(Input(X)) (%d)."
,
axis
,
num_dims
);
ctx
->
SetOutputDim
(
"Out"
,
in_dims
);
ctx
->
SetOutputDim
(
"Indices"
,
in_dims
);
ctx
->
ShareLoD
(
"X"
,
"Out"
);
ctx
->
ShareLoD
(
"X"
,
"Indices"
);
}
};
class
ArgsortOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) The input of Argsort op."
);
AddOutput
(
"Out"
,
"(Tensor) The sorted tensor of Argsort op, with the same "
"shape as Input(X)."
);
AddOutput
(
"Indices"
,
"(Tensor) The indices of a tensor giving the sorted order, with "
"the same shape as Input(X)."
);
AddComment
(
R"DOC(
Argsort operator
Performs sorting on the input tensor along the given axis and outputs two
tensors, Output(Out) and Output(Indices). They reserve the same shape
with Input(X), and Output(Out) represents the sorted tensor while
Output(Indices) gives the sorted order along the given axis Attr(axis).
)DOC"
);
AddAttr
<
int
>
(
"axis"
,
"(int, default -1) The axis along which to sort the tensor. "
"When axis < 0, the actual axis will be the |axis|'th "
"counting backwards. Default -1, the last dimension."
)
.
SetDefault
(
-
1
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
argsort
,
ops
::
ArgsortOp
,
ops
::
ArgsortOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
argsort
,
ops
::
ArgsortKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
ArgsortKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
paddle/fluid/operators/argsort_op.cu
0 → 100644
浏览文件 @
5f79c7fb
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/argsort_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
const
int
kMaxRank
=
9
;
// The max rank of a tensor allowed in Fluid
__global__
void
ComputeTargetIdx
(
const
int64_t
*
in_dims
,
int
dims_size
,
int
axis
,
int64_t
n
,
int64_t
*
trg_idx
,
int64_t
*
med_ids
)
{
int64_t
index
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
index
<
n
)
{
int64_t
shape_out_axis
[
kMaxRank
-
1
]
=
{
0
};
int64_t
dims_out_axis
[
kMaxRank
-
1
]
=
{
0
};
int64_t
tmp
=
index
;
int64_t
pos_in_axis
=
0
;
int64_t
i
=
dims_size
-
2
;
int64_t
dim_axis
=
0
;
for
(
int64_t
j
=
dims_size
-
1
;
j
>=
0
;
--
j
)
{
int64_t
dim
=
in_dims
[
j
];
if
(
j
!=
axis
)
{
shape_out_axis
[
i
]
=
tmp
%
dim
;
dims_out_axis
[
i
]
=
dim
;
i
--
;
}
else
{
dim_axis
=
dim
;
pos_in_axis
=
tmp
%
dim_axis
;
}
tmp
/=
dim
;
}
int64_t
group
=
(
dims_size
>
1
)
?
shape_out_axis
[
0
]
:
0
;
for
(
int64_t
j
=
0
;
j
<
dims_size
-
2
;
++
j
)
{
group
=
group
*
dims_out_axis
[
j
+
1
]
+
shape_out_axis
[
j
+
1
];
}
int64_t
traget_idx
=
group
*
dim_axis
+
pos_in_axis
;
trg_idx
[
index
]
=
traget_idx
;
med_ids
[
traget_idx
]
=
pos_in_axis
;
}
}
template
<
typename
T
>
__global__
void
PermuteInData
(
const
T
*
in
,
const
int64_t
*
trg_idx
,
int64_t
n
,
T
*
med_out
)
{
int
index
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
index
<
n
)
{
med_out
[
trg_idx
[
index
]]
=
in
[
index
];
}
}
template
<
typename
T
>
__global__
void
Sort
(
int64_t
axis_dim
,
int64_t
groups
,
T
*
med_out
,
int64_t
*
med_ids
)
{
int
index
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
index
<
groups
)
{
thrust
::
sort_by_key
(
thrust
::
device
,
med_out
+
index
*
axis_dim
,
med_out
+
axis_dim
*
(
1
+
index
),
med_ids
+
index
*
axis_dim
);
}
}
template
<
typename
T
>
__global__
void
PermuteMediateData
(
const
T
*
med_out
,
const
int64_t
*
med_ids
,
const
int64_t
*
trg_idx
,
int64_t
n
,
T
*
out
,
int64_t
*
indices
)
{
int
index
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
index
<
n
)
{
out
[
index
]
=
med_out
[
trg_idx
[
index
]];
indices
[
index
]
=
med_ids
[
trg_idx
[
index
]];
}
}
template
<
typename
T
>
class
ArgsortOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
indices
=
ctx
.
Output
<
Tensor
>
(
"Indices"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
in_dims
=
input
->
dims
();
axis
=
(
axis
<
0
)
?
(
in_dims
.
size
()
+
axis
)
:
axis
;
const
T
*
in_data
=
input
->
data
<
T
>
();
T
*
out_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
*
ids_data
=
indices
->
mutable_data
<
int64_t
>
(
ctx
.
GetPlace
());
int64_t
numel
=
input
->
numel
();
int64_t
groups
=
numel
/
in_dims
[
axis
];
std
::
vector
<
int64_t
>
in_dims_vec
=
vectorize
(
in_dims
);
thrust
::
device_vector
<
int64_t
>
in_dims_dev
(
in_dims_vec
.
begin
(),
in_dims_vec
.
end
());
int64_t
*
in_dims_data
=
thrust
::
raw_pointer_cast
(
in_dims_dev
.
data
());
// Mediate tensor for sorting data and indices
Tensor
mediate_output
,
mediate_indices
;
T
*
med_out_data
=
mediate_output
.
mutable_data
<
T
>
(
input
->
dims
(),
ctx
.
GetPlace
());
int64_t
*
med_ids_data
=
mediate_indices
.
mutable_data
<
int64_t
>
(
in_dims
,
ctx
.
GetPlace
());
// Target index of each element along the given axis in the mediate tensors
Tensor
trg_idx_t
;
int64_t
*
trg_idx
=
trg_idx_t
.
mutable_data
<
int64_t
>
(
in_dims
,
ctx
.
GetPlace
());
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
const
int
num_threads
=
PADDLE_CUDA_NUM_THREADS
;
ComputeTargetIdx
<<<
(
numel
-
1
)
/
num_threads
+
1
,
num_threads
,
0
,
stream
>>>
(
in_dims_data
,
in_dims
.
size
(),
axis
,
numel
,
trg_idx
,
med_ids_data
);
PermuteInData
<<<
(
numel
-
1
)
/
num_threads
+
1
,
num_threads
,
0
,
stream
>>>
(
in_data
,
trg_idx
,
numel
,
med_out_data
);
Sort
<<<
(
groups
-
1
)
/
num_threads
+
1
,
num_threads
,
0
,
stream
>>>
(
in_dims
[
axis
],
groups
,
med_out_data
,
med_ids_data
);
PermuteMediateData
<<<
(
numel
-
1
)
/
num_threads
+
1
,
num_threads
,
0
,
stream
>>>
(
med_out_data
,
med_ids_data
,
trg_idx
,
numel
,
out_data
,
ids_data
);
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
argsort
,
paddle
::
operators
::
ArgsortOpCUDAKernel
<
float
>
,
paddle
::
operators
::
ArgsortOpCUDAKernel
<
double
>
);
paddle/fluid/operators/argsort_op.h
0 → 100644
浏览文件 @
5f79c7fb
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
ArgsortKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
indices
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Indices"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
in_dims
=
input
->
dims
();
axis
=
(
axis
<
0
)
?
(
in_dims
.
size
()
+
axis
)
:
axis
;
const
T
*
in_data
=
input
->
data
<
T
>
();
T
*
out_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
*
ids_data
=
indices
->
mutable_data
<
int64_t
>
(
ctx
.
GetPlace
());
int64_t
groups
=
input
->
numel
()
/
in_dims
[
axis
];
int64_t
stride
=
(
axis
==
in_dims
.
size
()
-
1
)
?
1
:
framework
::
product
(
framework
::
slice_ddim
(
in_dims
,
axis
+
1
,
in_dims
.
size
()));
for
(
int64_t
i
=
0
;
i
<
groups
;
++
i
)
{
int64_t
idx
=
i
;
std
::
vector
<
int64_t
>
shape_vec
(
in_dims
.
size
(),
0
);
for
(
int64_t
dim
=
in_dims
.
size
()
-
1
;
dim
>=
0
;
--
dim
)
{
if
(
dim
!=
axis
)
{
shape_vec
[
dim
]
=
idx
%
in_dims
[
dim
];
idx
/=
in_dims
[
dim
];
}
}
int64_t
start_index
=
shape_vec
[
0
];
for
(
int64_t
dim
=
0
;
dim
<
in_dims
.
size
()
-
1
;
++
dim
)
{
start_index
=
start_index
*
in_dims
[
dim
+
1
]
+
shape_vec
[
dim
+
1
];
}
std
::
vector
<
int64_t
>
org_index_vec
(
in_dims
[
axis
],
start_index
);
for
(
int64_t
j
=
1
;
j
<
in_dims
[
axis
];
++
j
)
{
org_index_vec
[
j
]
+=
j
*
stride
;
}
std
::
sort
(
org_index_vec
.
begin
(),
org_index_vec
.
end
(),
[
in_data
](
const
int64_t
v1
,
const
int64_t
v2
)
{
return
in_data
[
v1
]
<
in_data
[
v2
];
});
for
(
size_t
j
=
0
;
j
<
org_index_vec
.
size
();
++
j
)
{
int64_t
index
=
start_index
+
j
*
stride
;
out_data
[
index
]
=
in_data
[
org_index_vec
[
j
]];
ids_data
[
index
]
=
(
org_index_vec
[
j
]
-
start_index
)
/
stride
;
}
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/tensor.py
浏览文件 @
5f79c7fb
...
@@ -33,6 +33,7 @@ __all__ = [
...
@@ -33,6 +33,7 @@ __all__ = [
'fill_constant'
,
'fill_constant'
,
'argmin'
,
'argmin'
,
'argmax'
,
'argmax'
,
'argsort'
,
'ones'
,
'ones'
,
'zeros'
,
'zeros'
,
'reverse'
,
'reverse'
,
...
@@ -444,6 +445,58 @@ def argmax(x, axis=0):
...
@@ -444,6 +445,58 @@ def argmax(x, axis=0):
return
out
return
out
def
argsort
(
input
,
axis
=-
1
,
name
=
None
):
"""
Performs sorting on the input Variable along the given axis, and outputs
sorted data Varibale and its corresponding index Variable with the same
shape as :attr:`input`.
.. code-block:: text
For example, the given axis is -1 and the input Variable
input = [[0.15849551, 0.45865775, 0.8563702 ],
[0.12070083, 0.28766365, 0.18776911]],
after argsort, the sorted Vairable becomes
out = [[0.15849551, 0.45865775, 0.8563702 ],
[0.12070083, 0.18776911, 0.28766365]],
and the sorted indices along the given axis turn outs to be
indices = [[0, 1, 2],
[0, 2, 1]]
Args:
input(Variable): The input Variable for sorting.
axis(int): The axis along which to sort the input Variable. When
:attr:`axis` < 0, the actual axis will be :attr:`axis` +
rank(:attr:`input`). Default -1, the last dimension.
name(str|None): (optional) A name for this layer. If set None, the
layer will be named automatically.
Returns:
tuple: A tuple of sorted data Variable and the sorted indices.
Examples:
.. code-block:: python
input = fluid.layers.data(data=[2, 3])
out, indices = fluid.layers.argsort(input, axis=0)
"""
helper
=
LayerHelper
(
"argsort"
,
**
locals
())
out
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
,
stop_gradient
=
True
)
ids
=
helper
.
create_tmp_variable
(
VarDesc
.
VarType
.
INT64
,
stop_gradient
=
True
)
helper
.
append_op
(
type
=
'argsort'
,
inputs
=
{
'X'
:
input
},
outputs
=
{
'Out'
:
out
,
'Indices'
:
ids
},
attrs
=
{
'axis'
:
axis
})
return
out
,
ids
def
ones
(
shape
,
dtype
,
force_cpu
=
False
):
def
ones
(
shape
,
dtype
,
force_cpu
=
False
):
"""
"""
**ones**
**ones**
...
...
python/paddle/fluid/tests/unittests/test_argsort_op.py
0 → 100644
浏览文件 @
5f79c7fb
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
class
TestArgsortOp
(
OpTest
):
def
setUp
(
self
):
self
.
init_axis
()
x
=
np
.
random
.
random
((
2
,
3
,
4
,
5
,
10
)).
astype
(
"float32"
)
if
self
.
axis
<
0
:
self
.
axis
=
self
.
axis
+
len
(
x
.
shape
)
self
.
indices
=
np
.
argsort
(
x
,
kind
=
'quicksort'
,
axis
=
self
.
axis
)
self
.
out
=
np
.
sort
(
x
,
kind
=
'quicksort'
,
axis
=
self
.
axis
)
self
.
op_type
=
"argsort"
self
.
inputs
=
{
'X'
:
x
}
self
.
attrs
=
{
'axis'
:
self
.
axis
}
self
.
outputs
=
{
'Indices'
:
self
.
indices
,
'Out'
:
self
.
out
}
def
init_axis
(
self
):
self
.
axis
=
-
1
def
test_check_output
(
self
):
self
.
check_output
()
class
TestArgsortOpAxis0
(
TestArgsortOp
):
def
init_axis
(
self
):
self
.
axis
=
0
class
TestArgsortOpAxis1
(
TestArgsortOp
):
def
init_axis
(
self
):
self
.
axis
=
1
class
TestArgsortOpAxisNeg2
(
TestArgsortOp
):
def
init_axis
(
self
):
self
.
axis
=
-
2
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
5f79c7fb
...
@@ -419,6 +419,15 @@ class TestBook(unittest.TestCase):
...
@@ -419,6 +419,15 @@ class TestBook(unittest.TestCase):
self
.
assertIsNotNone
(
iou
)
self
.
assertIsNotNone
(
iou
)
print
(
str
(
program
))
print
(
str
(
program
))
def
test_argsort
(
self
):
program
=
Program
()
with
program_guard
(
program
):
data
=
layers
.
data
(
name
=
'x'
,
shape
=
[
2
,
3
,
3
],
dtype
=
"float32"
)
out
,
ids
=
layers
.
argsort
(
input
=
data
,
axis
=
1
)
self
.
assertIsNotNone
(
out
)
self
.
assertIsNotNone
(
ids
)
print
(
str
(
program
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录