Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d6038c22
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d6038c22
编写于
2月 24, 2022
作者:
L
Li Min
提交者:
GitHub
2月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize performance of lookup_table_v2_op (#39856)
* optimize block config and fp16 atomicAdd perf for lookup_table_v2_grad.
上级
76a6b88d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
115 addition
and
18 deletion
+115
-18
paddle/fluid/operators/lookup_table_v2_op.cu
paddle/fluid/operators/lookup_table_v2_op.cu
+27
-18
paddle/fluid/platform/device/gpu/gpu_primitives.h
paddle/fluid/platform/device/gpu/gpu_primitives.h
+88
-0
未找到文件。
paddle/fluid/operators/lookup_table_v2_op.cu
浏览文件 @
d6038c22
...
...
@@ -21,19 +21,18 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
typename
IdT
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
,
bool
PaddingFlag
>
template
<
typename
T
,
typename
IdT
,
bool
PaddingFlag
>
__global__
void
LookupTableV2
(
T
*
output
,
const
T
*
table
,
const
IdT
*
ids
,
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
,
const
int64_t
padding_idx
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
gridDim
.
x
;
while
(
idy
<
K
)
{
auto
id
=
static_cast
<
int64_t
>
(
ids
[
idy
]);
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
for
(
int
i
=
idx
;
i
<
D
;
i
+=
blockDim
.
x
)
{
if
(
PaddingFlag
)
{
if
(
id
==
padding_idx
)
out
[
i
]
=
static_cast
<
T
>
(
0
);
...
...
@@ -43,25 +42,29 @@ __global__ void LookupTableV2(T *output, const T *table, const IdT *ids,
out
[
i
]
=
tab
[
i
];
}
}
idy
+=
BlockDimY
*
GridDimX
;
idy
+=
blockDim
.
y
*
gridDim
.
x
;
}
}
template
<
typename
T
,
typename
IdT
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
template
<
typename
T
,
typename
IdT
>
__global__
void
LookupTableV2Grad
(
T
*
table
,
const
T
*
output
,
const
IdT
*
ids
,
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
gridDim
.
x
;
while
(
idy
<
K
)
{
auto
id
=
static_cast
<
int64_t
>
(
ids
[
idy
]);
const
T
*
out
=
output
+
idy
*
D
;
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
#ifdef PADDLE_WITH_CUDA
paddle
::
platform
::
VectorizedAtomicAddPerBlock
(
D
,
idx
,
blockDim
.
x
,
out
,
tab
);
#else
for
(
int
i
=
idx
;
i
<
D
;
i
+=
blockDim
.
x
)
{
paddle
::
platform
::
CudaAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
}
idy
+=
BlockDimY
*
GridDimX
;
#endif
idy
+=
blockDim
.
y
*
gridDim
.
x
;
}
}
...
...
@@ -81,8 +84,9 @@ struct LookupTableV2CUDAFunctor {
size_t
D
=
table_t
->
dims
()[
1
];
size_t
K
=
ids_t_
->
numel
();
const
int
gridx
=
2
*
context_
.
cuda_device_context
().
GetSMCount
();
dim3
threads
(
256
,
4
);
dim3
grids
(
80
,
1
);
dim3
grids
(
gridx
,
1
);
const
auto
*
table
=
table_t
->
template
data
<
T
>();
const
auto
*
ids
=
ids_t_
->
template
data
<
IdT
>();
...
...
@@ -90,10 +94,10 @@ struct LookupTableV2CUDAFunctor {
auto
stream
=
context_
.
cuda_device_context
().
stream
();
if
(
padding_idx
==
-
1
)
{
LookupTableV2
<
T
,
IdT
,
256
,
4
,
80
,
false
><<<
grids
,
threads
,
0
,
stream
>>>
(
LookupTableV2
<
T
,
IdT
,
false
><<<
grids
,
threads
,
0
,
stream
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
}
else
{
LookupTableV2
<
T
,
IdT
,
256
,
4
,
80
,
true
><<<
grids
,
threads
,
0
,
stream
>>>
(
LookupTableV2
<
T
,
IdT
,
true
><<<
grids
,
threads
,
0
,
stream
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
}
}
...
...
@@ -193,17 +197,22 @@ struct LookupTableV2GradCUDAFunctor {
int
D
=
d_table_t
->
dims
()[
1
];
int
K
=
ids_t_
->
numel
();
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
const
T
*
d_output
=
d_output_t
->
template
data
<
T
>();
const
auto
*
ids
=
ids_t_
->
template
data
<
IdT
>();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context_
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_table_t
);
t
.
device
(
*
dev_ctx
.
eigen_device
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS
(
hipMemsetAsync
(
d_table
,
0
,
N
*
D
*
sizeof
(
T
),
dev_ctx
.
stream
()));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMemsetAsync
(
d_table
,
0
,
N
*
D
*
sizeof
(
T
),
dev_ctx
.
stream
()));
#endif
LookupTableV2Grad
<
T
,
IdT
,
128
,
8
,
8
><<<
grids
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
const
int
gridx
=
2
*
dev_ctx
.
GetSMCount
();
dim3
threads
(
128
,
8
);
dim3
grids
(
gridx
,
1
);
LookupTableV2Grad
<
T
,
IdT
><<<
grids
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
}
}
...
...
paddle/fluid/platform/device/gpu/gpu_primitives.h
浏览文件 @
d6038c22
...
...
@@ -147,6 +147,94 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
}
}
#endif
// The performance of "atomicAdd(half* )" is bad, but for "atomicAdd(half2* )"
// is good. So for fp16 type, we can use "atomicAdd(half2* )" to speed up.
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_same
<
platform
::
float16
,
T
>
::
value
>::
type
*
=
nullptr
>
__device__
__forceinline__
void
fastAtomicAdd
(
T
*
tensor
,
size_t
index
,
const
size_t
numel
,
T
value
)
{
#if ((CUDA_VERSION < 10000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
CudaAtomicAdd
(
reinterpret_cast
<
platform
::
float16
*>
(
tensor
)
+
index
,
static_cast
<
platform
::
float16
>
(
value
));
#else
// whether the address is 32-byte aligned.
__half
*
target_addr
=
reinterpret_cast
<
__half
*>
(
tensor
+
index
);
bool
aligned_half2
=
(
reinterpret_cast
<
std
::
uintptr_t
>
(
target_addr
)
%
sizeof
(
__half2
)
==
0
);
if
(
aligned_half2
&&
index
<
(
numel
-
1
))
{
__half2
value2
;
value2
.
x
=
*
reinterpret_cast
<
__half
*>
(
&
value
);
value2
.
y
=
__int2half_rz
(
0
);
atomicAdd
(
reinterpret_cast
<
__half2
*>
(
target_addr
),
value2
);
}
else
if
(
!
aligned_half2
&&
index
>
0
)
{
__half2
value2
;
value2
.
x
=
__int2half_rz
(
0
);
value2
.
y
=
*
reinterpret_cast
<
__half
*>
(
&
value
);
atomicAdd
(
reinterpret_cast
<
__half2
*>
(
target_addr
-
1
),
value2
);
}
else
{
atomicAdd
(
reinterpret_cast
<
__half
*>
(
tensor
)
+
index
,
*
reinterpret_cast
<
__half
*>
(
&
value
));
}
#endif
}
template
<
typename
T
,
typename
std
::
enable_if
<!
std
::
is_same
<
platform
::
float16
,
T
>
::
value
>::
type
*
=
nullptr
>
__device__
__forceinline__
void
fastAtomicAdd
(
T
*
arr
,
size_t
index
,
const
size_t
numel
,
T
value
)
{
CudaAtomicAdd
(
arr
+
index
,
value
);
}
#ifdef PADDLE_WITH_CUDA
/*
* One thead block deals with elementwise atomicAdd for vector of len.
* @in: [x1, x2, x3, ...]
* @out:[y1+x1, y2+x2, y3+x3, ...]
* */
template
<
typename
T
,
typename
std
::
enable_if
<!
std
::
is_same
<
platform
::
float16
,
T
>
::
value
>::
type
*
=
nullptr
>
__device__
__forceinline__
void
VectorizedAtomicAddPerBlock
(
const
int64_t
len
,
int
tid
,
int
threads_per_block
,
const
T
*
in
,
T
*
out
)
{
for
(
int
i
=
tid
;
i
<
len
;
i
+=
threads_per_block
)
{
CudaAtomicAdd
(
&
out
[
i
],
in
[
i
]);
}
}
// Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_same
<
platform
::
float16
,
T
>
::
value
>::
type
*
=
nullptr
>
__device__
__forceinline__
void
VectorizedAtomicAddPerBlock
(
const
int64_t
len
,
int
tid
,
int
threads_per_block
,
const
T
*
in
,
T
*
out
)
{
int
i
=
0
;
int
loops
=
len
/
2
*
2
;
bool
aligned_half2
=
(
reinterpret_cast
<
std
::
uintptr_t
>
(
out
)
%
sizeof
(
__half2
)
==
0
);
if
(
aligned_half2
)
{
for
(
i
=
tid
*
2
;
i
<
loops
;
i
+=
threads_per_block
*
2
)
{
__half2
value2
;
T
value_1
=
in
[
i
];
T
value_2
=
in
[
i
+
1
];
value2
.
x
=
*
reinterpret_cast
<
__half
*>
(
&
value_1
);
value2
.
y
=
*
reinterpret_cast
<
__half
*>
(
&
value_2
);
atomicAdd
(
reinterpret_cast
<
__half2
*>
(
&
out
[
i
]),
value2
);
}
for
(;
i
<
len
;
i
+=
threads_per_block
)
{
fastAtomicAdd
(
out
,
i
,
len
,
in
[
i
]);
}
}
else
{
for
(
int
i
=
tid
;
i
<
len
;
i
+=
threads_per_block
)
{
fastAtomicAdd
(
out
,
i
,
len
,
in
[
i
]);
}
}
}
#endif
#endif
CUDA_ATOMIC_WRAPPER
(
Add
,
complex
<
float
>
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录