Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3ff5cc2d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
3ff5cc2d
编写于
11月 20, 2019
作者:
Z
zhaoyuchen2018
提交者:
GitHub
11月 20, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix topk compile failed on windows (#21243)
* Fix topk compile failed on windows * Use explicit cast for assign data
上级
2e2f92a5
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
18 addition
and
15 deletion
+18
-15
paddle/fluid/operators/top_k_op.cu
paddle/fluid/operators/top_k_op.cu
+18
-15
未找到文件。
paddle/fluid/operators/top_k_op.cu
浏览文件 @
3ff5cc2d
...
...
@@ -336,12 +336,13 @@ struct ColumnIndexIter {
int
num_cols_
;
};
__global__
void
InitIndex
(
int64_t
*
indices
,
int
num_rows
,
int
num_cols
)
{
__global__
void
InitIndex
(
int64_t
*
indices
,
int64_t
num_rows
,
int64_t
num_cols
)
{
int
col_id
=
threadIdx
.
x
;
int
row_id
=
blockIdx
.
x
;
for
(
int
j
=
row_id
;
j
<
num_rows
;
j
+=
gridDim
.
x
)
{
for
(
int
i
=
col_id
;
i
<
num_cols
;
i
+=
blockDim
.
x
)
{
for
(
int
64_t
j
=
row_id
;
j
<
num_rows
;
j
+=
gridDim
.
x
)
{
for
(
int
64_t
i
=
col_id
;
i
<
num_cols
;
i
+=
blockDim
.
x
)
{
indices
[
j
*
num_cols
+
i
]
=
i
;
}
}
...
...
@@ -349,14 +350,14 @@ __global__ void InitIndex(int64_t* indices, int num_rows, int num_cols) {
template
<
typename
T
>
bool
SortTopk
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
*
input_tensor
,
const
size_t
num_cols
,
const
size_t
num_rows
,
size_t
k
,
framework
::
Tensor
*
out_tensor
,
const
framework
::
Tensor
*
input_tensor
,
const
int64_t
num_cols
,
const
int64_t
num_rows
,
const
int
k
,
framework
::
Tensor
*
out_tensor
,
framework
::
Tensor
*
indices_tensor
)
{
auto
cu_stream
=
ctx
.
stream
();
Tensor
input_indices
;
const
std
::
vector
<
int64_t
>
dims
=
{
static_cast
<
int64_t
>
(
num_rows
),
static_cast
<
int64_t
>
(
num_cols
)};
const
std
::
vector
<
int64_t
>
dims
=
{
num_rows
,
num_cols
};
auto
dim
=
framework
::
make_ddim
(
dims
);
input_indices
.
Resize
(
dim
);
// input_indices.Resize(num_rows*num_cols);
...
...
@@ -378,18 +379,20 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
int
block_size
=
ComputeBlockSize
(
num_cols
);
int
maxGridDimX
=
ctx
.
GetCUDAMaxGridDimSize
().
x
;
unsigned
int
maxGridDimX
=
ctx
.
GetCUDAMaxGridDimSize
().
x
;
// actually, int num_rows < max_grid_size
int
grid_size
=
num_rows
<
maxGridDimX
?
num_rows
:
maxGridDimX
;
unsigned
int
grid_size
=
num_rows
<
maxGridDimX
?
static_cast
<
unsigned
int
>
(
num_rows
)
:
maxGridDimX
;
// Init a index array
InitIndex
<<<
grid_size
,
block_size
,
0
,
cu_stream
>>>
(
input_indices
.
data
<
int64_t
>
(),
num_rows
,
num_cols
);
// create iter for counting input
cub
::
CountingInputIterator
<
int
>
counting_iter
(
0
);
cub
::
CountingInputIterator
<
int
64_t
>
counting_iter
(
0
);
// segment_offset is used for move to next row
cub
::
TransformInputIterator
<
int
,
SegmentOffsetIter
,
cub
::
CountingInputIterator
<
int
>>
cub
::
TransformInputIterator
<
int
64_t
,
SegmentOffsetIter
,
cub
::
CountingInputIterator
<
int
64_t
>>
segment_offsets_t
(
counting_iter
,
SegmentOffsetIter
(
num_cols
));
T
*
sorted_values_ptr
;
...
...
@@ -484,7 +487,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
indices
=
ctx
.
Output
<
Tensor
>
(
"Indices"
);
size_
t
k
=
static_cast
<
int
>
(
ctx
.
Attr
<
int
>
(
"k"
));
in
t
k
=
static_cast
<
int
>
(
ctx
.
Attr
<
int
>
(
"k"
));
auto
*
k_t
=
ctx
.
Input
<
Tensor
>
(
"K"
);
if
(
k_t
)
{
...
...
@@ -502,9 +505,9 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
// FIXME(typhoonzero): data is always converted to type T?
framework
::
DDim
inputdims
=
input
->
dims
();
const
size
_t
input_height
=
framework
::
product
(
const
int64
_t
input_height
=
framework
::
product
(
framework
::
slice_ddim
(
inputdims
,
0
,
inputdims
.
size
()
-
1
));
const
size
_t
input_width
=
inputdims
[
inputdims
.
size
()
-
1
];
const
int64
_t
input_width
=
inputdims
[
inputdims
.
size
()
-
1
];
const
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
if
((
input_width
<=
1024
||
k
>=
128
||
k
==
input_width
))
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录