Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f00f982a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f00f982a
编写于
8月 20, 2020
作者:
Z
Zhaolong Xing
提交者:
GitHub
8月 20, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add cub impl for arg max, min (#25941)
test=develop
上级
57d434df
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
206 addition
and
58 deletion
+206
-58
paddle/fluid/operators/arg_max_op.cu
paddle/fluid/operators/arg_max_op.cu
+22
-29
paddle/fluid/operators/arg_min_max_op_base.cu.h
paddle/fluid/operators/arg_min_max_op_base.cu.h
+163
-0
paddle/fluid/operators/arg_min_op.cu
paddle/fluid/operators/arg_min_op.cu
+21
-29
未找到文件。
paddle/fluid/operators/arg_max_op.cu
浏览文件 @
f00f982a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
arg_max
,
arg_max
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
float
,
cub
::
ArgMax
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
double
,
cub
::
ArgMax
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int64_t
,
cub
::
ArgMax
>
,
double
>
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int32_t
,
cub
::
ArgMax
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int8_t
,
cub
::
ArgMax
>
);
int64_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int32_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int16_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
);
paddle/fluid/operators/arg_min_max_op_base.cu.h
0 → 100644
浏览文件 @
f00f982a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef __NVCC__
#include <cub/cub.cuh>
#include <limits>
#include <string>
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
operators
{
namespace
{
// NOLINT
template
<
typename
K
,
typename
V
>
using
KeyValuePair
=
cub
::
KeyValuePair
<
K
,
V
>
;
using
Tensor
=
framework
::
Tensor
;
}
// end namespace
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \
case
(
1
<<
(
log2_block_dim
)):
{
\
constexpr
auto
kBlockDim
=
(
1
<<
(
log2_block_dim
));
\
__VA_ARGS__
;
\
}
break
#define FIXED_BLOCK_DIM_CASE(...) \
FIXED_BLOCK_DIM_CASE_BASE
(
10
,
##
__VA_ARGS__
);
\
FIXED_BLOCK_DIM_CASE_BASE
(
9
,
##
__VA_ARGS__
);
\
FIXED_BLOCK_DIM_CASE_BASE
(
8
,
##
__VA_ARGS__
);
\
FIXED_BLOCK_DIM_CASE_BASE
(
7
,
##
__VA_ARGS__
);
\
FIXED_BLOCK_DIM_CASE_BASE
(
6
,
##
__VA_ARGS__
);
\
FIXED_BLOCK_DIM_CASE_BASE
(
5
,
##
__VA_ARGS__
);
\
FIXED_BLOCK_DIM_CASE_BASE
(
4
,
##
__VA_ARGS__
);
\
FIXED_BLOCK_DIM_CASE_BASE
(
3
,
##
__VA_ARGS__
);
template
<
typename
T
,
typename
IndType
,
class
Reducer
,
size_t
BlockDim
>
__global__
void
ArgCUDAKernel
(
const
IndType
height
,
// n * h
const
IndType
width
,
// c
const
IndType
post_size
,
// h
const
Reducer
reducer
,
const
T
init
,
const
T
*
in
,
IndType
*
out
)
{
typedef
cub
::
BlockReduce
<
KeyValuePair
<
int
,
T
>
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
for
(
int
idx
=
blockIdx
.
x
;
idx
<
height
;
idx
+=
gridDim
.
x
)
{
KeyValuePair
<
int
,
T
>
kv_pair
=
{
-
1
,
init
};
int
h
=
idx
/
post_size
;
int
w
=
idx
%
post_size
;
for
(
int
k
=
threadIdx
.
x
;
k
<
width
;
k
+=
blockDim
.
x
)
{
kv_pair
=
reducer
({
k
,
in
[
h
*
width
*
post_size
+
k
*
post_size
+
w
]},
kv_pair
);
}
kv_pair
=
BlockReduce
(
temp_storage
).
Reduce
(
kv_pair
,
reducer
);
if
(
threadIdx
.
x
==
0
)
{
out
[
idx
]
=
static_cast
<
IndType
>
(
kv_pair
.
key
);
}
__syncthreads
();
}
}
template
<
typename
T
,
typename
IndType
,
class
Reducer
>
void
ComputeFullArg
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
input
,
Tensor
*
indices
,
const
IndType
pre
,
const
IndType
post
,
const
IndType
n
)
{
auto
cu_stream
=
ctx
.
stream
();
auto
ComputeBlockSize
=
[](
IndType
col
)
{
if
(
col
>
512
)
return
1024
;
else
if
(
col
>
256
)
return
512
;
else
if
(
col
>
128
)
return
256
;
else
if
(
col
>
64
)
return
128
;
else
if
(
col
>
32
)
return
64
;
else
if
(
col
>
16
)
return
32
;
else
if
(
col
>
8
)
return
16
;
else
return
8
;
};
int
max_grid_dimx
=
ctx
.
GetCUDAMaxGridDimSize
().
x
;
int
height
=
pre
*
post
;
int
width
=
n
;
int
grid_size
=
height
<
max_grid_dimx
?
height
:
max_grid_dimx
;
const
T
*
in_data
=
input
.
data
<
T
>
();
IndType
*
out_data
=
indices
->
mutable_data
<
IndType
>
(
ctx
.
GetPlace
());
if
(
typeid
(
Reducer
)
==
typeid
(
cub
::
ArgMax
))
{
switch
(
ComputeBlockSize
(
width
))
{
FIXED_BLOCK_DIM_CASE
(
ArgCUDAKernel
<
T
,
IndType
,
Reducer
,
kBlockDim
><<<
grid_size
,
kBlockDim
,
0
,
cu_stream
>>>
(
height
,
width
,
post
,
Reducer
(),
std
::
numeric_limits
<
T
>::
lowest
(),
in_data
,
out_data
));
}
}
else
{
switch
(
ComputeBlockSize
(
width
))
{
FIXED_BLOCK_DIM_CASE
(
ArgCUDAKernel
<
T
,
IndType
,
Reducer
,
kBlockDim
><<<
grid_size
,
kBlockDim
,
0
,
cu_stream
>>>
(
height
,
width
,
post
,
Reducer
(),
std
::
numeric_limits
<
T
>::
max
(),
in_data
,
out_data
));
}
}
}
template
<
typename
T
,
class
Reducer
>
class
ArgMinMaxOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
int
axis
=
ctx
.
Attr
<
int64_t
>
(
"axis"
);
auto
in_dims
=
input
->
dims
();
axis
=
(
axis
<
0
)
?
(
in_dims
.
size
()
+
axis
)
:
axis
;
int64_t
numel
=
input
->
numel
();
int64_t
groups
=
numel
/
in_dims
[
axis
];
int64_t
pre
=
1
;
int64_t
post
=
1
;
int64_t
n
=
in_dims
[
axis
];
for
(
int
i
=
0
;
i
<
axis
;
i
++
)
{
pre
*=
in_dims
[
i
];
}
for
(
int
i
=
axis
+
1
;
i
<
in_dims
.
size
();
i
++
)
{
post
*=
in_dims
[
i
];
}
const
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
ComputeFullArg
<
T
,
int64_t
,
Reducer
>
(
dev_ctx
,
*
input
,
output
,
pre
,
post
,
n
);
}
};
#endif
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/arg_min_op.cu
浏览文件 @
f00f982a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
arg_min
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
float
,
cub
::
ArgMin
>
,
arg_min
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
double
,
cub
::
ArgMin
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int64_t
,
cub
::
ArgMin
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int32_t
,
cub
::
ArgMin
>
,
double
>
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int8_t
,
cub
::
ArgMin
>
);
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int32_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int16_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录