Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5970871a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5970871a
编写于
8月 05, 2020
作者:
Z
Zhaolong Xing
提交者:
GitHub
8月 05, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add eltwise clip cuda impl. (#25689)
test=develop
上级
36027490
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
70 addition
and
9 deletion
+70
-9
paddle/fluid/operators/clip_op.h
paddle/fluid/operators/clip_op.h
+26
-9
paddle/fluid/operators/elementwise/elementwise_op_function.h
paddle/fluid/operators/elementwise/elementwise_op_function.h
+44
-0
未找到文件。
paddle/fluid/operators/clip_op.h
浏览文件 @
5970871a
...
...
@@ -25,17 +25,23 @@ namespace operators {
using
framework
::
Tensor
;
using
platform
::
Transform
;
#ifdef __NVCC__
template
<
typename
T
,
typename
UnaryOperation
>
__global__
void
ClipCudaKernel
(
const
T
*
input
,
T
*
out
,
int
num
,
UnaryOperation
op
)
{
int
idx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
idx
<
num
)
{
out
[
idx
]
=
op
(
input
[
idx
]);
}
}
#endif
template
<
typename
T
>
class
ClipFunctor
{
public:
explicit
ClipFunctor
(
const
T
min
,
const
T
max
)
:
min_
(
min
),
max_
(
max
)
{}
HOSTDEVICE
T
operator
()(
const
T
&
x
)
const
{
if
(
x
<
min_
)
return
min_
;
else
if
(
x
>
max_
)
return
max_
;
else
return
x
;
return
x
<
min_
?
min_
:
x
>
max_
?
max_
:
x
;
}
private:
...
...
@@ -97,9 +103,20 @@ class ClipKernel : public framework::OpKernel<T> {
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
x_data
=
x
->
data
<
T
>
();
int64_t
numel
=
x
->
numel
();
if
(
platform
::
is_gpu_place
(
context
.
GetPlace
()))
{
#ifdef __NVCC__
int
threads
=
256
;
int
blocks
=
(
numel
+
threads
-
1
)
/
threads
;
ClipCudaKernel
<
T
,
ClipFunctor
<
T
>><<<
blocks
,
threads
,
0
,
context
.
template
device_context
<
platform
::
CUDADeviceContext
>()
.
stream
()
>>>
(
x_data
,
out_data
,
numel
,
ClipFunctor
<
T
>
(
min
,
max
));
#endif
}
else
{
Transform
<
DeviceContext
>
trans
;
trans
(
context
.
template
device_context
<
DeviceContext
>(),
x_data
,
x_data
+
numel
,
out_data
,
ClipFunctor
<
T
>
(
min
,
max
));
}
}
else
if
(
x_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
x
=
context
.
Input
<
framework
::
SelectedRows
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
SelectedRows
>
(
"Out"
);
...
...
paddle/fluid/operators/elementwise/elementwise_op_function.h
浏览文件 @
5970871a
...
...
@@ -197,6 +197,40 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
}
#ifdef __NVCC__
template
<
typename
Functor
,
typename
T
,
typename
OutType
>
__global__
void
ElementwiseKernel
(
const
T
*
x
,
const
T
*
y
,
OutType
*
out
,
int
pre
,
int
n
,
int
post
,
int
total
,
Functor
func
)
{
int
tid
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
int
idx
=
tid
/
post
%
n
;
if
(
tid
<
total
)
{
out
[
tid
]
=
func
(
x
[
tid
],
y
[
idx
]);
}
}
template
<
typename
Functor
,
typename
T
,
typename
OutType
>
void
ComputeElementwiseCUDA
(
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
,
int
pre
,
int
n
,
int
post
,
const
platform
::
CUDADeviceContext
&
ctx
,
Functor
func
,
const
bool
is_xsize_larger
=
true
)
{
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
y_data
=
y
->
data
<
T
>
();
OutType
*
out_data
=
z
->
mutable_data
<
OutType
>
(
ctx
.
GetPlace
());
int
numel
=
pre
*
n
*
post
;
int
threads
=
256
;
int
blocks
=
(
numel
+
threads
-
1
)
/
threads
;
if
(
is_xsize_larger
)
{
ElementwiseKernel
<
Functor
,
T
,
OutType
><<<
blocks
,
threads
,
0
,
ctx
.
stream
()
>>>
(
x_data
,
y_data
,
out_data
,
pre
,
n
,
post
,
numel
,
func
);
}
else
{
ElementwiseKernel
<
Functor
,
T
,
OutType
><<<
blocks
,
threads
,
0
,
ctx
.
stream
()
>>>
(
y_data
,
x_data
,
out_data
,
pre
,
n
,
post
,
numel
,
func
);
}
}
template
<
typename
Functor
,
typename
T
,
typename
OutType
=
T
>
__global__
void
CommonForwardBroadcastCUDAKernel
(
const
int
*
x_strides_array
,
const
int
*
y_strides_array
,
...
...
@@ -1908,6 +1942,16 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
ctx
,
x
,
y
,
z
,
x_dims
,
y_dims
,
func
,
axis
,
is_xsize_larger
);
return
;
}
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef __NVCC__
ComputeElementwiseCUDA
<
Functor
,
T
,
OutType
>
(
x
,
y
,
z
,
pre
,
n
,
post
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>(),
func
,
is_xsize_larger
);
#endif
return
;
}
if
(
post
==
1
)
{
functor
.
RunRowWise
(
n
,
pre
);
return
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录