Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
42645ff7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
42645ff7
编写于
6月 12, 2018
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Compute target index on gpu
上级
6ee22c4f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
52 addition
and
39 deletion
+52
-39
paddle/fluid/operators/argsort_op.cc
paddle/fluid/operators/argsort_op.cc
+1
-1
paddle/fluid/operators/argsort_op.cu
paddle/fluid/operators/argsort_op.cu
+51
-38
未找到文件。
paddle/fluid/operators/argsort_op.cc
浏览文件 @
42645ff7
...
@@ -30,7 +30,7 @@ class ArgsortOp : public framework::OperatorWithKernel {
...
@@ -30,7 +30,7 @@ class ArgsortOp : public framework::OperatorWithKernel {
"Output(Indices) of ArgsortOp should not be null."
);
"Output(Indices) of ArgsortOp should not be null."
);
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
int
axis
=
static_cast
<
int
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
)
);
int
axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
);
auto
num_dims
=
in_dims
.
size
();
auto
num_dims
=
in_dims
.
size
();
PADDLE_ENFORCE
(
axis
<
num_dims
,
PADDLE_ENFORCE
(
axis
<
num_dims
,
...
...
paddle/fluid/operators/argsort_op.cu
浏览文件 @
42645ff7
...
@@ -26,6 +26,42 @@ namespace operators {
...
@@ -26,6 +26,42 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
/**
 * Maps every flat (row-major) element index of the input to its flat index in
 * a layout where the `axis` dimension is the innermost one, and records the
 * element's coordinate along `axis`.
 *
 * Each thread handles at most one element; threads whose global index is not
 * below `n` do nothing.
 *
 * @param in_dims    Array of `dims_size` dimension extents, read by the kernel.
 * @param dims_size  Number of entries in `in_dims`.
 * @param axis       Index into `in_dims` of the dimension to move innermost.
 * @param n          Number of elements to process.
 * @param trg_idx    Output, indexed by source element: trg_idx[index] receives
 *                   group * in_dims[axis] + pos_in_axis.
 * @param med_ids    Output, indexed by target position: med_ids[target]
 *                   receives the element's coordinate along `axis`.
 *
 * The group index (row-major linearisation of all coordinates except the one
 * along `axis`) is accumulated with a running stride while the flat index is
 * decomposed from the last dimension to the first. This avoids the per-thread
 * device-heap allocations (`new[]`/`delete[]`) of the previous version, which
 * were unchecked for nullptr and could exhaust the device heap on large
 * inputs.
 */
__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
                                 int axis, int64_t n, int64_t* trg_idx,
                                 int64_t* med_ids) {
  // Widen before multiplying so the global index cannot wrap in 32 bits.
  int64_t index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (index < n) {
    int64_t tmp = index;
    int64_t pos_in_axis = 0;
    int64_t dim_axis = 0;
    int64_t group = 0;
    // Product of the extents of the non-axis dimensions already visited,
    // i.e. the row-major stride of the current non-axis dimension.
    int64_t group_stride = 1;

    for (int64_t j = dims_size - 1; j >= 0; --j) {
      int64_t dim = in_dims[j];
      if (j != axis) {
        group += (tmp % dim) * group_stride;
        group_stride *= dim;
      } else {
        dim_axis = dim;
        pos_in_axis = tmp % dim_axis;
      }
      tmp /= dim;
    }

    int64_t target_idx = group * dim_axis + pos_in_axis;
    trg_idx[index] = target_idx;
    med_ids[target_idx] = pos_in_axis;
  }
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
PermuteInData
(
const
T
*
in
,
const
int64_t
*
trg_idx
,
int64_t
n
,
__global__
void
PermuteInData
(
const
T
*
in
,
const
int64_t
*
trg_idx
,
int64_t
n
,
T
*
med_out
)
{
T
*
med_out
)
{
...
@@ -76,50 +112,27 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -76,50 +112,27 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
int64_t
numel
=
input
->
numel
();
int64_t
numel
=
input
->
numel
();
int64_t
groups
=
numel
/
in_dims
[
axis
];
int64_t
groups
=
numel
/
in_dims
[
axis
];
// Mediate tensor for sorting
std
::
vector
<
int64_t
>
in_dims_vec
=
vectorize
(
in_dims
);
Tensor
mediate_output
;
thrust
::
device_vector
<
int64_t
>
in_dims_dev
(
in_dims_vec
.
begin
(),
in_dims_vec
.
end
());
int64_t
*
in_dims_data
=
thrust
::
raw_pointer_cast
(
in_dims_dev
.
data
());
// Mediate tensor for sorting data and indices
Tensor
mediate_output
,
mediate_indices
;
T
*
med_out_data
=
T
*
med_out_data
=
mediate_output
.
mutable_data
<
T
>
(
input
->
dims
(),
ctx
.
GetPlace
());
mediate_output
.
mutable_data
<
T
>
(
input
->
dims
(),
ctx
.
GetPlace
());
int64_t
*
med_ids_data
=
// The target index of each element in mediate tensor
mediate_indices
.
mutable_data
<
int64_t
>
(
in_dims
,
ctx
.
GetPlace
());
std
::
vector
<
int64_t
>
target_idx
(
numel
,
0
);
// Target index of each element along the given axis in the mediate tensors
// To record the index along the given axis for the data in mediate tensor
Tensor
trg_idx_t
;
std
::
vector
<
int64_t
>
mediate_indices
(
numel
,
0
);
int64_t
*
trg_idx
=
trg_idx_t
.
mutable_data
<
int64_t
>
(
in_dims
,
ctx
.
GetPlace
());
std
::
vector
<
int64_t
>
in_dims_out_axis
=
vectorize
(
in_dims
);
in_dims_out_axis
.
erase
(
in_dims_out_axis
.
begin
()
+
axis
);
for
(
int64_t
index
=
0
;
index
<
numel
;
++
index
)
{
int64_t
tmp
=
index
;
int64_t
pos_in_axis
=
0
;
std
::
vector
<
int64_t
>
shape
;
for
(
int64_t
j
=
in_dims
.
size
()
-
1
;
j
>=
0
;
--
j
)
{
if
(
j
!=
axis
)
{
shape
.
push_back
(
tmp
%
in_dims
[
j
]);
}
else
{
pos_in_axis
=
tmp
%
in_dims
[
j
];
}
tmp
/=
in_dims
[
j
];
}
std
::
reverse
(
shape
.
begin
(),
shape
.
end
());
int64_t
group
=
(
shape
.
size
()
>
0
)
?
shape
[
0
]
:
0
;
for
(
size_t
j
=
0
;
j
<
shape
.
size
()
-
1
;
++
j
)
{
group
=
group
*
in_dims_out_axis
[
j
+
1
]
+
shape
[
j
+
1
];
}
target_idx
[
index
]
=
group
*
in_dims
[
axis
]
+
pos_in_axis
;
mediate_indices
[
target_idx
[
index
]]
=
pos_in_axis
;
}
thrust
::
device_vector
<
int64_t
>
med_ids_dev
(
mediate_indices
.
begin
(),
mediate_indices
.
end
());
int64_t
*
med_ids_data
=
thrust
::
raw_pointer_cast
(
med_ids_dev
.
data
());
thrust
::
device_vector
<
int64_t
>
trg_idx_dev
(
target_idx
.
begin
(),
target_idx
.
end
());
int64_t
*
trg_idx
=
thrust
::
raw_pointer_cast
(
trg_idx_dev
.
data
());
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
ctx
.
device_context
())
.
stream
();
.
stream
();
auto
num_threads
=
PADDLE_CUDA_NUM_THREADS
;
int
num_threads
=
PADDLE_CUDA_NUM_THREADS
;
ComputeTargetIdx
<<<
(
numel
-
1
)
/
num_threads
+
1
,
num_threads
,
0
,
stream
>>>
(
in_dims_data
,
in_dims
.
size
(),
axis
,
numel
,
trg_idx
,
med_ids_data
);
PermuteInData
<<<
(
numel
-
1
)
/
num_threads
+
1
,
num_threads
,
0
,
stream
>>>
(
PermuteInData
<<<
(
numel
-
1
)
/
num_threads
+
1
,
num_threads
,
0
,
stream
>>>
(
in_data
,
trg_idx
,
numel
,
med_out_data
);
in_data
,
trg_idx
,
numel
,
med_out_data
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录