Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7ba14d74
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7ba14d74
编写于
3月 19, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix lookup speed error
上级
755ad257
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
29 addition
and
27 deletion
+29
-27
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+19
-13
paddle/phi/kernels/gpu/embedding_kernel.cu
paddle/phi/kernels/gpu/embedding_kernel.cu
+10
-14
未找到文件。
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
浏览文件 @
7ba14d74
...
@@ -35,7 +35,7 @@ __global__ void InputTypeConvert(const InT* in_ids,
...
@@ -35,7 +35,7 @@ __global__ void InputTypeConvert(const InT* in_ids,
}
}
}
}
template
<
typename
T
,
typename
IdT
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
template
<
typename
T
,
typename
IdT
>
__global__
void
LookupTableV2Grad
(
T
*
table
,
__global__
void
LookupTableV2Grad
(
T
*
table
,
const
T
*
output
,
const
T
*
output
,
const
IdT
*
ids
,
const
IdT
*
ids
,
...
@@ -43,16 +43,20 @@ __global__ void LookupTableV2Grad(T* table,
...
@@ -43,16 +43,20 @@ __global__ void LookupTableV2Grad(T* table,
const
int64_t
K
,
const
int64_t
K
,
const
int64_t
D
)
{
const
int64_t
D
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
gridDim
.
x
;
while
(
idy
<
K
)
{
while
(
idy
<
K
)
{
auto
id
=
static_cast
<
int64_t
>
(
ids
[
idy
]);
auto
id
=
static_cast
<
int64_t
>
(
ids
[
idy
]);
const
T
*
out
=
output
+
idy
*
D
;
const
T
*
out
=
output
+
idy
*
D
;
T
*
tab
=
table
+
id
*
D
;
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
#ifdef PADDLE_WITH_CUDA
paddle
::
platform
::
VectorizedAtomicAddPerBlock
(
D
,
idx
,
blockDim
.
x
,
out
,
tab
);
#else
for
(
int
i
=
idx
;
i
<
D
;
i
+=
blockDim
.
x
)
{
paddle
::
platform
::
CudaAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
paddle
::
platform
::
CudaAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
}
}
idy
+=
BlockDimY
*
GridDimX
;
#endif
idy
+=
blockDim
.
y
*
gridDim
.
x
;
}
}
}
}
...
@@ -83,20 +87,22 @@ struct LookupTableV2GradCUDAFunctor {
...
@@ -83,20 +87,22 @@ struct LookupTableV2GradCUDAFunctor {
int
D
=
weight_grad_
->
dims
()[
1
];
int
D
=
weight_grad_
->
dims
()[
1
];
int
K
=
input_
.
numel
();
int
K
=
input_
.
numel
();
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
const
T
*
d_output
=
d_output_t
.
template
data
<
T
>();
const
T
*
d_output
=
d_output_t
.
template
data
<
T
>();
const
auto
*
ids
=
input_
.
template
data
<
IdT
>();
const
auto
*
ids
=
input_
.
template
data
<
IdT
>();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
dev_ctx_
.
GetPlace
());
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
dev_ctx_
.
GetPlace
());
auto
t
=
EigenVector
<
T
>::
Flatten
(
*
d_table_t
);
#ifdef PADDLE_WITH_HIP
t
.
device
(
*
dev_ctx_
.
eigen_device
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
PADDLE_ENFORCE_GPU_SUCCESS
(
hipMemsetAsync
(
d_table
,
0
,
N
*
D
*
sizeof
(
T
),
dev_ctx_
.
stream
()));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMemsetAsync
(
d_table
,
0
,
N
*
D
*
sizeof
(
T
),
dev_ctx_
.
stream
()));
#endif
LookupTableV2Grad
<
T
,
const
int
gridx
=
2
*
dev_ctx_
.
GetSMCount
();
IdT
,
dim3
threads
(
128
,
8
);
128
,
dim3
grids
(
gridx
,
1
);
8
,
LookupTableV2Grad
<
T
,
IdT
><<<
grids
,
threads
,
0
,
dev_ctx_
.
stream
()
>>>
(
8
><<<
grids
,
threads
,
0
,
dev_ctx_
.
stream
()
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
}
}
}
}
...
...
paddle/phi/kernels/gpu/embedding_kernel.cu
浏览文件 @
7ba14d74
...
@@ -23,12 +23,7 @@
...
@@ -23,12 +23,7 @@
namespace
phi
{
namespace
phi
{
template
<
typename
T
,
template
<
typename
T
,
typename
IdT
,
bool
PaddingFlag
>
typename
IdT
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
,
bool
PaddingFlag
>
__global__
void
LookupTableV2
(
T
*
output
,
__global__
void
LookupTableV2
(
T
*
output
,
const
T
*
table
,
const
T
*
table
,
const
IdT
*
ids
,
const
IdT
*
ids
,
...
@@ -37,13 +32,13 @@ __global__ void LookupTableV2(T *output,
...
@@ -37,13 +32,13 @@ __global__ void LookupTableV2(T *output,
const
int64_t
D
,
const
int64_t
D
,
const
int64_t
padding_idx
)
{
const
int64_t
padding_idx
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
gridDim
.
x
;
while
(
idy
<
K
)
{
while
(
idy
<
K
)
{
auto
id
=
static_cast
<
int64_t
>
(
ids
[
idy
]);
auto
id
=
static_cast
<
int64_t
>
(
ids
[
idy
]);
T
*
out
=
output
+
idy
*
D
;
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
for
(
int
i
=
idx
;
i
<
D
;
i
+=
blockDim
.
x
)
{
if
(
PaddingFlag
)
{
if
(
PaddingFlag
)
{
if
(
id
==
padding_idx
)
if
(
id
==
padding_idx
)
out
[
i
]
=
static_cast
<
T
>
(
0
);
out
[
i
]
=
static_cast
<
T
>
(
0
);
...
@@ -53,7 +48,7 @@ __global__ void LookupTableV2(T *output,
...
@@ -53,7 +48,7 @@ __global__ void LookupTableV2(T *output,
out
[
i
]
=
tab
[
i
];
out
[
i
]
=
tab
[
i
];
}
}
}
}
idy
+=
BlockDimY
*
GridDimX
;
idy
+=
blockDim
.
y
*
gridDim
.
x
;
}
}
}
}
...
@@ -76,19 +71,20 @@ struct LookupTableV2CUDAFunctor {
...
@@ -76,19 +71,20 @@ struct LookupTableV2CUDAFunctor {
size_t
D
=
weight_
.
dims
()[
1
];
size_t
D
=
weight_
.
dims
()[
1
];
size_t
K
=
input_
.
numel
();
size_t
K
=
input_
.
numel
();
const
int
gridx
=
2
*
dev_ctx_
.
GetSMCount
();
dim3
threads
(
256
,
4
);
dim3
threads
(
256
,
4
);
dim3
grids
(
80
,
1
);
dim3
grids
(
gridx
,
1
);
const
auto
*
table
=
weight_
.
template
data
<
T
>();
const
T
*
table
=
weight_
.
template
data
<
T
>();
const
auto
*
ids
=
input_
.
template
data
<
IdT
>();
const
IdT
*
ids
=
input_
.
template
data
<
IdT
>();
auto
*
output
=
out_
->
template
mutable_data
<
T
>(
dev_ctx_
.
GetPlace
());
auto
*
output
=
out_
->
template
mutable_data
<
T
>(
dev_ctx_
.
GetPlace
());
auto
stream
=
dev_ctx_
.
stream
();
auto
stream
=
dev_ctx_
.
stream
();
if
(
padding_idx_
==
-
1
)
{
if
(
padding_idx_
==
-
1
)
{
LookupTableV2
<
T
,
IdT
,
256
,
4
,
80
,
false
><<<
grids
,
threads
,
0
,
stream
>>>
(
LookupTableV2
<
T
,
IdT
,
false
><<<
grids
,
threads
,
0
,
stream
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx_
);
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx_
);
}
else
{
}
else
{
LookupTableV2
<
T
,
IdT
,
256
,
4
,
80
,
true
><<<
grids
,
threads
,
0
,
stream
>>>
(
LookupTableV2
<
T
,
IdT
,
true
><<<
grids
,
threads
,
0
,
stream
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx_
);
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx_
);
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录