Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1b69e528
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1b69e528
编写于
12月 15, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
12月 15, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize for long width for elementwise (#29602)
上级
78dad786
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
93 addition
and
3 deletion
+93
-3
paddle/fluid/operators/elementwise/elementwise_add_op.h
paddle/fluid/operators/elementwise/elementwise_add_op.h
+93
-3
未找到文件。
paddle/fluid/operators/elementwise/elementwise_add_op.h
浏览文件 @
1b69e528
...
@@ -19,6 +19,11 @@ limitations under the License. */
...
@@ -19,6 +19,11 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_CUDA
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -121,6 +126,20 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
...
@@ -121,6 +126,20 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#ifdef __NVCC__
#ifdef __NVCC__
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
AlignedVector
{
T
val
[
Size
];
};
template
<
typename
T
>
inline
int
VectorizedSize
(
const
T
*
pointer
)
{
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec4
=
std
::
alignment_of
<
AlignedVector
<
T
,
4
>>::
value
;
// NOLINT
if
(
address
%
vec4
==
0
)
{
return
4
;
}
return
1
;
}
template
<
typename
T
,
int
BLOCK_W
,
int
BLOCK_H
>
template
<
typename
T
,
int
BLOCK_W
,
int
BLOCK_H
>
__global__
void
MatrixColReduce
(
const
T
*
__restrict__
in
,
T
*
__restrict__
out
,
__global__
void
MatrixColReduce
(
const
T
*
__restrict__
in
,
T
*
__restrict__
out
,
size_t
width
,
size_t
height
)
{
size_t
width
,
size_t
height
)
{
...
@@ -200,6 +219,45 @@ __global__ void FP16MatrixColReduce(
...
@@ -200,6 +219,45 @@ __global__ void FP16MatrixColReduce(
if
((
threadIdx
.
y
==
0
)
&&
((
w
)
<
width
))
out
[
w
]
=
sdata
[
0
][
threadIdx
.
x
];
if
((
threadIdx
.
y
==
0
)
&&
((
w
)
<
width
))
out
[
w
]
=
sdata
[
0
][
threadIdx
.
x
];
}
}
}
}
template
<
typename
T
>
__global__
void
MatrixReduceLongWidth
(
const
T
*
__restrict__
in
,
T
*
out
,
size_t
width
,
size_t
height
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
for
(;
idx
<
width
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
row
=
0
;
row
<
height
;
row
++
)
{
sum
+=
in
[
idx
+
row
*
width
];
}
out
[
idx
]
=
sum
;
}
}
template
<
typename
T
,
int
VEC_SIZE
>
__global__
void
VecMatrixReduceLongWidth
(
const
T
*
__restrict__
in
,
T
*
out
,
size_t
width
,
size_t
height
)
{
using
LoadT
=
AlignedVector
<
T
,
VEC_SIZE
>
;
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
w
=
idx
*
VEC_SIZE
;
int
width_stride
=
blockDim
.
x
*
gridDim
.
x
*
VEC_SIZE
;
for
(;
w
<
width
;
w
+=
width
)
{
T
zero
=
static_cast
<
T
>
(
0
);
T
sum
[
VEC_SIZE
]
=
{
zero
};
T
tmp_vec
[
VEC_SIZE
]
=
{
zero
};
LoadT
*
tmp_ptr
=
reinterpret_cast
<
LoadT
*>
(
&
tmp_vec
);
for
(
int
row
=
0
;
row
<
height
;
row
++
)
{
int
offset
=
width
*
row
+
w
;
*
tmp_ptr
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
in
[
offset
]);
for
(
int
v
=
0
;
v
<
VEC_SIZE
;
v
++
)
{
sum
[
v
]
+=
tmp_vec
[
v
];
}
}
for
(
int
v
=
0
;
v
<
VEC_SIZE
;
v
++
)
out
[
w
+
v
]
=
sum
[
v
];
}
}
#endif
#endif
#endif
#endif
bool
static
RunSpecialDims
(
const
framework
::
DDim
&
dx_dims
,
bool
static
RunSpecialDims
(
const
framework
::
DDim
&
dx_dims
,
...
@@ -301,6 +359,21 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
...
@@ -301,6 +359,21 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
*
dout
,
ctx
.
GetPlace
(),
*
dout
,
ctx
.
GetPlace
(),
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dy
);
ctx
.
template
device_context
<
platform
::
DeviceContext
>(),
dy
);
}
}
// special optimization using cub
if
(
width
==
1
)
{
int
nums
=
height
;
size_t
temp_storage_bytes
=
0
;
auto
err
=
cub
::
DeviceReduce
::
Sum
(
nullptr
,
temp_storage_bytes
,
dout_data
,
out_data
,
nums
,
stream
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
err
);
framework
::
Tensor
tmp
;
auto
*
temp_storage
=
tmp
.
mutable_data
<
uint8_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
temp_storage_bytes
)}),
ctx
.
GetPlace
());
err
=
cub
::
DeviceReduce
::
Sum
(
temp_storage
,
temp_storage_bytes
,
dout_data
,
out_data
,
nums
,
stream
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
err
);
}
constexpr
int
block_x
=
32
;
constexpr
int
block_x
=
32
;
constexpr
int
block_y
=
32
;
constexpr
int
block_y
=
32
;
...
@@ -311,7 +384,8 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
...
@@ -311,7 +384,8 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
int
max_blocks
=
std
::
max
(
max_physical_threads
/
(
block_x
*
block_y
),
1
);
int
max_blocks
=
std
::
max
(
max_physical_threads
/
(
block_x
*
block_y
),
1
);
int
theory_block
=
(
width
+
blocks
.
x
-
1
)
/
blocks
.
x
;
int
theory_block
=
(
width
+
blocks
.
x
-
1
)
/
blocks
.
x
;
dim3
grids
(
std
::
min
(
theory_block
,
max_blocks
));
dim3
grids
(
std
::
min
(
theory_block
,
max_blocks
));
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
&&
(
width
/
height
)
<
32
)
{
const
paddle
::
platform
::
float16
*
ptr1
=
const
paddle
::
platform
::
float16
*
ptr1
=
reinterpret_cast
<
const
paddle
::
platform
::
float16
*>
(
dout_data
);
reinterpret_cast
<
const
paddle
::
platform
::
float16
*>
(
dout_data
);
paddle
::
platform
::
float16
*
ptr2
=
paddle
::
platform
::
float16
*
ptr2
=
...
@@ -325,8 +399,24 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
...
@@ -325,8 +399,24 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
}
}
return
;
return
;
}
}
MatrixColReduce
<
T
,
block_x
,
block_y
><<<
grids
,
blocks
,
0
,
stream
>>>
(
dout_data
,
out_data
,
width
,
height
);
if
(
width
/
height
<
32
)
{
MatrixColReduce
<
T
,
block_x
,
block_y
><<<
grids
,
blocks
,
0
,
stream
>>>
(
dout_data
,
out_data
,
width
,
height
);
}
else
{
size_t
thread_nums
=
1024
;
size_t
block_nums
=
(
width
+
thread_nums
-
1
)
/
thread_nums
;
int
vec_size
=
VectorizedSize
<
T
>
(
dx_data
);
if
(
vec_size
==
4
&&
width
%
4
==
0
)
{
block_nums
=
(
width
/
vec_size
+
thread_nums
-
1
)
/
thread_nums
;
VecMatrixReduceLongWidth
<
T
,
4
><<<
block_nums
,
thread_nums
,
0
,
stream
>>>
(
dout_data
,
out_data
,
width
,
height
);
}
else
{
MatrixReduceLongWidth
<
T
><<<
block_nums
,
thread_nums
,
0
,
stream
>>>
(
dout_data
,
out_data
,
width
,
height
);
}
}
return
;
return
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录