Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
95b95a28
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
95b95a28
编写于
12月 09, 2019
作者:
W
wangchaochaohu
提交者:
GitHub
12月 09, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Mean gpu optimize (#21643)
* accelerate mean op test=develop
上级
48600d7f
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
87 addition
and
3 deletion
+87
-3
paddle/fluid/operators/mean_op.cu
paddle/fluid/operators/mean_op.cu
+87
-3
未找到文件。
paddle/fluid/operators/mean_op.cu
浏览文件 @
95b95a28
...
...
@@ -11,15 +11,99 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
DivideFunctor
{
HOSTDEVICE
explicit
inline
DivideFunctor
(
int
n
)
:
n_inv
(
static_cast
<
T
>
(
1.0
/
n
))
{}
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
x
*
n_inv
;
}
private:
T
n_inv
;
};
template
<
typename
T
>
__global__
void
MeanRunKernel
(
const
T
in_data
,
T
*
out_data
,
int
N
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
out_data
[
idx
]
=
in_data
/
(
static_cast
<
T
>
(
N
));
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
MeanCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
"Out"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
size_prob
=
input
->
numel
();
const
T
*
in_data
=
input
->
data
<
T
>
();
T
*
out_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
stream
=
context
.
cuda_device_context
().
stream
();
DivideFunctor
<
T
>
transformer
(
size_prob
);
cub
::
TransformInputIterator
<
T
,
DivideFunctor
<
T
>
,
const
T
*>
trans_x
(
in_data
,
transformer
);
size_t
temp_storage_bytes
=
0
;
auto
err
=
cub
::
DeviceReduce
::
Sum
(
nullptr
,
temp_storage_bytes
,
trans_x
,
out_data
,
size_prob
,
stream
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
err
,
"MeanOP failed to get reduce workspace size"
,
cudaGetErrorString
(
err
));
framework
::
Tensor
tmp
;
auto
*
temp_storage
=
tmp
.
mutable_data
<
uint8_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
temp_storage_bytes
)}),
context
.
GetPlace
());
err
=
cub
::
DeviceReduce
::
Sum
(
temp_storage
,
temp_storage_bytes
,
trans_x
,
out_data
,
size_prob
,
stream
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
err
,
"MeanOP failed to run reduce computation"
,
cudaGetErrorString
(
err
));
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
MeanCUDAGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
OG
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
PADDLE_ENFORCE_EQ
(
OG
->
numel
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Mean Gradient Input Tensor len should be 1. But received %d"
,
OG
->
numel
()));
auto
IG
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
IG
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
in_data
=
OG
[
0
];
auto
size_prob
=
IG
->
numel
();
auto
out_data
=
IG
->
data
<
T
>
();
int
threads
=
512
;
int
grid
=
(
size_prob
+
threads
-
1
)
/
threads
;
auto
stream
=
context
.
cuda_device_context
().
stream
();
MeanRunKernel
<
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
in_data
,
out_data
,
size_prob
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
mean
,
ops
::
MeanKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MeanKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
MeanKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
);
mean
,
ops
::
Mean
CUDA
Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
Mean
CUDA
Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
Mean
CUDA
Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
mean_grad
,
ops
::
MeanGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MeanGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录