Commit 1f45b313 (unverified)

cherry pick dev branch for embedding grad (#53332)

Authored by shaojie_wang on Apr 25, 2023; committed via GitHub on Apr 26, 2023.
Parent: 3f2f4040

Showing 2 changed files with 129 additions and 11 deletions (+129 -11):

  paddle/phi/core/flags.cc                            +16    -0
  paddle/phi/kernels/gpu/embedding_grad_kernel.cu    +113   -11
paddle/phi/core/flags.cc

@@ -232,6 +232,22 @@ PADDLE_DEFINE_EXPORTED_bool(
     "operator. The autotuning algorithm may be non-deterministic. If "
     "true, the algorithm is deterministic.");
 
+/**
+ * CUDA related FLAG
+ * Name: FLAGS_embedding_deterministic
+ * Since Version: 2.5
+ * Value Range: bool, default=false
+ * Example:
+ * Note: whether to use a deterministic algorithm in the embedding op.
+ * If true, it will use a deterministic CUDA kernel in the embedding op.
+ */
+PADDLE_DEFINE_EXPORTED_bool(
+    embedding_deterministic,
+    false,
+    "Whether to allow using a deterministic algorithm for the embedding "
+    "operator. The deterministic algorithm may be slower. If "
+    "true, the algorithm is deterministic.");
 
 /**
  * CUDNN related FLAG
  * Name: FLAGS_conv_workspace_size_limit
...
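Editor's note: the define/declare split used by this commit follows the usual gflags pattern, where one translation unit owns the flag's definition and any other file reads it after declaring it; Paddle's PADDLE_DEFINE_EXPORTED_bool is, to my understanding, a wrapper that additionally lets the flag be set through a FLAGS_embedding_deterministic environment variable. A minimal single-file sketch using plain gflags macros (an assumption for illustration, not Paddle's exact wrapper):

```cpp
// Sketch only: in Paddle the DEFINE lives in flags.cc and the kernel file
// uses DECLARE_bool(embedding_deterministic); to read the same flag.
#include <cstdio>
#include <gflags/gflags.h>

DEFINE_bool(embedding_deterministic, false,
            "Use a deterministic (possibly slower) embedding backward.");

void LaunchEmbeddingBackward() {
  if (FLAGS_embedding_deterministic) {
    std::puts("deterministic path");   // would launch EmbeddingGradDeterministic
  } else {
    std::puts("atomics-based path");   // would launch EmbeddingGrad
  }
}

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  LaunchEmbeddingBackward();  // e.g. ./a.out --embedding_deterministic=true
  return 0;
}
```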
paddle/phi/kernels/gpu/embedding_grad_kernel.cu

@@ -18,6 +18,7 @@
 #include "glog/logging.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
...
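Editor's note: the new include provides phi::dtype::MPTypeTrait, which the deterministic kernel below uses to pick a wider "master" type MT for its shared-memory accumulator, so that float16/bfloat16 gradients are summed in float. A minimal sketch of the trait pattern (the float16 stand-in and the specializations here are illustrative, not Paddle's exact source):

```cpp
#include <cstdint>

struct float16 { uint16_t bits; };  // hypothetical stand-in for phi::dtype::float16

// Map a storage type to the type used for accumulation.
template <typename T>
struct MPTypeTrait {
  using Type = T;  // full-precision types accumulate in themselves
};

template <>
struct MPTypeTrait<float16> {
  using Type = float;  // half-precision gradients accumulate in float
};

// Usage, as in the kernel: using MT = typename MPTypeTrait<T>::Type;
```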
@@ -25,10 +26,20 @@
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
 DECLARE_bool(cudnn_deterministic);
+DECLARE_bool(embedding_deterministic);
 
 namespace phi {
 
+#ifdef PADDLE_WITH_HIP
+#define WARP_SIZE 64
+#define BLOCKDIMY 16
+#else
+#define WARP_SIZE 32
+#define BLOCKDIMY 32
+#endif
+#define MASK 0xffffffff
+
 template <typename InT, typename OutT>
 __global__ void InputTypeConvert(const InT* in_ids,
                                  const int64_t K,
...
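Editor's note on the platform split: a HIP wavefront is 64 lanes wide where a CUDA warp is 32, and BLOCKDIMY is chosen so that both configurations reach the usual 1024-thread block maximum. A quick compile-time check of that arithmetic (a fragment, not part of the commit):

```cpp
// Both settings give 1024 threads per block, the per-block limit on
// current NVIDIA and AMD GPUs.
static_assert(64 * 16 == 1024, "HIP:  WARP_SIZE * BLOCKDIMY");
static_assert(32 * 32 == 1024, "CUDA: WARP_SIZE * BLOCKDIMY");
```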
@@ -63,6 +74,91 @@ __global__ void EmbeddingGrad(T* table,
   }
 }
 
+template <typename T, typename IdT>
+__global__ void EmbeddingGradDeterministic(
+    T* table, const T* output, const IdT* ids, const IdT K, const IdT D) {
+  using MT = typename dtype::MPTypeTrait<T>::Type;
+  extern __shared__ char buf[];
+  MT* smem = reinterpret_cast<MT*>(buf);
+  MT* my_s = smem + WARP_SIZE * threadIdx.y;
+  IdT* indices_batch =
+      reinterpret_cast<IdT*>(buf + sizeof(MT) * WARP_SIZE * BLOCKDIMY);
+
+  const int stride = static_cast<int>(D);
+  const int feature = threadIdx.x + blockIdx.x * WARP_SIZE;
+
+  // To ensure determinism: if any other warps pulled grad data targeting
+  // dst_row, we elect the first warp in each matching group as the leader.
+  // Each leader warp serializes the accumulations targeting dst_row in
+  // shared memory, then adds the accumulated buffer to dst_row in table.
+  for (int batch_start = 0; batch_start < K;
+       batch_start += WARP_SIZE * BLOCKDIMY) {
+    int tid = threadIdx.x + threadIdx.y * WARP_SIZE;
+    if (batch_start + tid < K)
+      indices_batch[tid] = static_cast<IdT>(ids[batch_start + tid]);
+
+    int batch_end =
+        min(static_cast<IdT>(batch_start + WARP_SIZE * BLOCKDIMY), K);
+
+    // Loop over the batch of <= 1024 loaded indices in chunks of BLOCKDIMY.
+    for (int chunk_start = batch_start; chunk_start < batch_end;
+         chunk_start += BLOCKDIMY) {
+      // This sync makes sure that indices_batch is ready and that
+      // match-group leaders are done with their accumulations before
+      // other warps start loading again.
+      __syncthreads();
+
+      int n_this_chunk = min(batch_end - chunk_start, BLOCKDIMY);
+
+      IdT src_row = static_cast<IdT>(chunk_start + threadIdx.y);
+      IdT dst_row = indices_batch[src_row - batch_start];
+      if (src_row < K && feature < stride)
+        my_s[threadIdx.x] = static_cast<MT>(output[src_row * D + feature]);
+
+      __syncthreads();
+
+      if (src_row < K) {
+        int match_found_this_thread = 0;
+        if (threadIdx.x < n_this_chunk) {
+          match_found_this_thread =
+              (dst_row ==
+               indices_batch[chunk_start - batch_start + threadIdx.x]);
+        }
+#ifdef PADDLE_WITH_HIP
+        unsigned long long int matchmask =      // NOLINT
+            __ballot(match_found_this_thread);  // NOLINT
+        int first_remaining_peer = __ffsll(matchmask) - 1;
+#else
+        // If and only if match_found_this_thread of the Nth thread is
+        // non-zero, set the Nth bit of matchmask to 1.
+        unsigned int matchmask =
+            __ballot_sync(MASK, match_found_this_thread);
+        // Find the position of the first bit set to 1 in matchmask.
+        int first_remaining_peer = __ffs(matchmask) - 1;
+#endif
+
+        // Select the lowest-indexed warp as the leader.
+        if (threadIdx.y == first_remaining_peer) {
+          // Clear the first set bit in matchmask.
+          matchmask ^= (1 << first_remaining_peer);
+          while (matchmask) {
+#ifdef PADDLE_WITH_HIP
+            first_remaining_peer = __ffsll(matchmask) - 1;
+#else
+            first_remaining_peer = __ffs(matchmask) - 1;
+#endif
+            my_s[threadIdx.x] +=
+                smem[threadIdx.x + WARP_SIZE * first_remaining_peer];
+            matchmask ^= (1 << first_remaining_peer);
+          }
+          if (feature < stride)
+            table[dst_row * D + feature] +=
+                static_cast<T>(my_s[threadIdx.x]);
+        }
+      }
+    }
+  }
+}
+
 template <typename T, typename Context>
 struct EmbeddingGradCUDAFunctor {
   EmbeddingGradCUDAFunctor(const Context& dev_ctx,
...
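Editor's note: the ballot/ffs pair above is a standard warp-vote idiom. All 32 lanes of a warp share one dst_row; lane x votes on whether chunk row x collides with it, __ballot_sync packs the votes into a bitmask visible to every lane, and __ffs names the lowest matching index, which is then compared against threadIdx.y to elect a single leader warp. A self-contained CUDA sketch of the same idiom (kernel and variable names here are illustrative, not part of the commit):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Each of the 32 lanes holds a key; find the lowest lane whose key equals
// keys[0], using the same __ballot_sync + __ffs election as the kernel.
__global__ void BallotDemo(const int* keys, int* first_match) {
  const unsigned kFullMask = 0xffffffffu;  // all 32 lanes participate
  int lane = threadIdx.x;                  // launch with exactly 32 threads
  int target = keys[0];                    // key we are searching for

  // Bit N of `mask` is set iff lane N voted true.
  unsigned mask = __ballot_sync(kFullMask, keys[lane] == target);

  // Every lane sees the same mask; __ffs is 1-based, so subtract 1.
  if (lane == 0) *first_match = __ffs(mask) - 1;
}

int main() {
  int h_keys[32];
  for (int i = 0; i < 32; ++i) h_keys[i] = i % 7;  // duplicated keys
  int *d_keys, *d_first;
  cudaMalloc(&d_keys, sizeof(h_keys));
  cudaMalloc(&d_first, sizeof(int));
  cudaMemcpy(d_keys, h_keys, sizeof(h_keys), cudaMemcpyHostToDevice);
  BallotDemo<<<1, 32>>>(d_keys, d_first);
  int first = -1;
  cudaMemcpy(&first, d_first, sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("first lane matching keys[0]: %d\n", first);  // prints 0
  cudaFree(d_keys);
  cudaFree(d_first);
  return 0;
}
```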
@@ -102,17 +198,23 @@ struct EmbeddingGradCUDAFunctor {
           cudaMemsetAsync(
               d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
 #endif
 
-      const int gridx = 2 * dev_ctx_.GetSMCount();
-      dim3 threads(128, 8);
-      dim3 grids(gridx, 1);
-      if (FLAGS_cudnn_deterministic) {
-        VLOG(2) << "Run grad kernel of embedding with single thread.";
-        grids.x = 1;
-        threads.y = 1;
+      if (FLAGS_embedding_deterministic) {
+        dim3 threads(WARP_SIZE, BLOCKDIMY);
+        dim3 grids(static_cast<int>((D + WARP_SIZE - 1) / WARP_SIZE));
+        using MT = typename dtype::MPTypeTrait<T>::Type;
+        EmbeddingGradDeterministic<T, IdT>
+            <<<grids,
+               threads,
+               sizeof(MT) * WARP_SIZE * BLOCKDIMY +
+                   sizeof(IdT) * WARP_SIZE * BLOCKDIMY,
+               dev_ctx_.stream()>>>(d_table, d_output, ids, K, D);
+      } else {
+        const int gridx = 2 * dev_ctx_.GetSMCount();
+        dim3 threads(128, 8);
+        dim3 grids(gridx, 1);
+        EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
+            d_table, d_output, ids, N, K, D);
       }
-      EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
-          d_table, d_output, ids, N, K, D);
     }
   }
...
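Editor's note on the new launch configuration: the deterministic kernel tiles the feature dimension D across blocks of WARP_SIZE x BLOCKDIMY threads and requests one MT accumulator slot plus one IdT index slot of dynamic shared memory per thread. A worked example of that footprint (illustrative values; assumes CUDA, T = float so MT = float, IdT = int64_t, and a hypothetical embedding width D = 1000):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int kWarpSize = 32, kBlockDimY = 32;  // CUDA values of the macros
  const int64_t D = 1000;                     // hypothetical embedding width

  // One block covers a WARP_SIZE-wide slice of the feature dimension.
  int grid_x = static_cast<int>((D + kWarpSize - 1) / kWarpSize);  // 32

  // Dynamic shared memory: an MT slot and an IdT slot per thread.
  size_t smem = sizeof(float) * kWarpSize * kBlockDimY +
                sizeof(int64_t) * kWarpSize * kBlockDimY;  // 12288 bytes

  std::printf("grid.x = %d, dynamic smem = %zu bytes\n", grid_x, smem);
  return 0;
}
```

By contrast, the default path keeps the occupancy-oriented launch (2 * SM count blocks of 128 x 8 threads) and accumulates gradient rows with atomic adds, whose scheduling-dependent order is what lets its results vary between runs; the new flag trades that throughput for a fixed accumulation order.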